@@ -1247,10 +1247,10 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
12471247
12481248 // / Get the address of the type descriptor global variable that was created by
12491249 // / lowering for derived type \p recType.
1250- mlir::Value getTypeDescriptor (mlir::ModuleOp mod,
1251- mlir::ConversionPatternRewriter &rewriter,
1252- mlir::Location loc ,
1253- fir::RecordType recType) const {
1250+ template < typename ModOpTy>
1251+ mlir::Value
1252+ getTypeDescriptor (ModOpTy mod, mlir::ConversionPatternRewriter &rewriter ,
1253+ mlir::Location loc, fir::RecordType recType) const {
12541254 std::string name =
12551255 this ->options .typeDescriptorsRenamedForAssembly
12561256 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
@@ -1275,7 +1275,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
12751275 return rewriter.create <mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
12761276 }
12771277
1278- mlir::Value populateDescriptor (mlir::Location loc, mlir::ModuleOp mod,
1278+ template <typename ModOpTy>
1279+ mlir::Value populateDescriptor (mlir::Location loc, ModOpTy mod,
12791280 fir::BaseBoxType boxTy, mlir::Type inputType,
12801281 mlir::ConversionPatternRewriter &rewriter,
12811282 unsigned rank, mlir::Value eleSize,
@@ -1414,10 +1415,16 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
14141415 extraField =
14151416 this ->getExtraFromBox (loc, sourceBoxTyPair, sourceBox, rewriter);
14161417 }
1417- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1418- mlir::Value descriptor =
1419- populateDescriptor (loc, mod, boxTy, inputType, rewriter, rank, eleSize,
1420- cfiTy, typeDesc, allocatorIdx, extraField);
1418+
1419+ mlir::Value descriptor;
1420+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1421+ descriptor = populateDescriptor (loc, gpuMod, boxTy, inputType, rewriter,
1422+ rank, eleSize, cfiTy, typeDesc,
1423+ allocatorIdx, extraField);
1424+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1425+ descriptor = populateDescriptor (loc, mod, boxTy, inputType, rewriter,
1426+ rank, eleSize, cfiTy, typeDesc,
1427+ allocatorIdx, extraField);
14211428
14221429 return {boxTy, descriptor, eleSize};
14231430 }
@@ -1460,11 +1467,17 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
14601467 mlir::Value extraField =
14611468 this ->getExtraFromBox (loc, inputBoxTyPair, loweredBox, rewriter);
14621469
1463- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1464- mlir::Value descriptor =
1465- populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1466- rank, eleSize, cfiTy, typeDesc,
1467- /* allocatorIdx=*/ kDefaultAllocator , extraField);
1470+ mlir::Value descriptor;
1471+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1472+ descriptor =
1473+ populateDescriptor (loc, gpuMod, boxTy, box.getBox ().getType (),
1474+ rewriter, rank, eleSize, cfiTy, typeDesc,
1475+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
1476+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1477+ descriptor =
1478+ populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1479+ rank, eleSize, cfiTy, typeDesc,
1480+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
14681481
14691482 return {boxTy, descriptor, eleSize};
14701483 }
0 commit comments