@@ -219,6 +219,29 @@ struct LowerGpuOpsToROCDLOpsPass
219219 gpu::GPUModuleOp m = getOperation ();
220220 MLIRContext *ctx = m.getContext ();
221221
222+ ArrayAttr targets = m.getTargetsAttr ();
223+ if (chipset == " infer" ) {
224+ if (!targets) {
225+ emitError (UnknownLoc::get (ctx),
226+ " ROCDLTargetAttr is empty on GPU module" );
227+ return signalPassFailure ();
228+ }
229+ if (targets.size () != 1 ) {
230+ emitError (UnknownLoc::get (ctx), " ROCDLTargetAttrs has more specified "
231+ " more than one gpu-arch on GPU module" );
232+ return signalPassFailure ();
233+ }
234+ const ROCDL::ROCDLTargetAttr targetAttr =
235+ mlir::dyn_cast<ROCDL::ROCDLTargetAttr>(targets.getValue ().front ());
236+ chipset = targetAttr.getChip ().str ();
237+ }
238+
239+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
240+ if (failed (maybeChipset)) {
241+ emitError (UnknownLoc::get (ctx), " Invalid chipset name: " + chipset);
242+ return signalPassFailure ();
243+ }
244+
222245 auto llvmDataLayout = m->getAttrOfType <StringAttr>(
223246 LLVM::LLVMDialect::getDataLayoutAttrName ());
224247 if (!llvmDataLayout) {
@@ -231,12 +254,6 @@ struct LowerGpuOpsToROCDLOpsPass
231254 UnitAttr::get (ctx));
232255 }
233256
234- FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
235- if (failed (maybeChipset)) {
236- emitError (UnknownLoc::get (ctx), " Invalid chipset name: " + chipset);
237- return signalPassFailure ();
238- }
239-
240257 // / Customize the bitwidth used for the device side index computations.
241258 LowerToLLVMOptions options (
242259 ctx, DataLayout (cast<DataLayoutOpInterface>(m.getOperation ())));
0 commit comments