@@ -1341,10 +1341,10 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
13411341
13421342 // / Get the address of the type descriptor global variable that was created by
13431343 // / lowering for derived type \p recType.
1344- mlir::Value getTypeDescriptor (mlir::ModuleOp mod,
1345- mlir::ConversionPatternRewriter &rewriter,
1346- mlir::Location loc ,
1347- fir::RecordType recType) const {
1344+ template < typename ModOpTy>
1345+ mlir::Value
1346+ getTypeDescriptor (ModOpTy mod, mlir::ConversionPatternRewriter &rewriter ,
1347+ mlir::Location loc, fir::RecordType recType) const {
13481348 std::string name =
13491349 this ->options .typeDescriptorsRenamedForAssembly
13501350 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
@@ -1369,7 +1369,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
13691369 return rewriter.create <mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
13701370 }
13711371
1372- mlir::Value populateDescriptor (mlir::Location loc, mlir::ModuleOp mod,
1372+ template <typename ModOpTy>
1373+ mlir::Value populateDescriptor (mlir::Location loc, ModOpTy mod,
13731374 fir::BaseBoxType boxTy, mlir::Type inputType,
13741375 mlir::ConversionPatternRewriter &rewriter,
13751376 unsigned rank, mlir::Value eleSize,
@@ -1508,10 +1509,16 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15081509 extraField =
15091510 this ->getExtraFromBox (loc, sourceBoxTyPair, sourceBox, rewriter);
15101511 }
1511- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1512- mlir::Value descriptor =
1513- populateDescriptor (loc, mod, boxTy, inputType, rewriter, rank, eleSize,
1514- cfiTy, typeDesc, allocatorIdx, extraField);
1512+
1513+ mlir::Value descriptor;
1514+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1515+ descriptor = populateDescriptor (loc, gpuMod, boxTy, inputType, rewriter,
1516+ rank, eleSize, cfiTy, typeDesc,
1517+ allocatorIdx, extraField);
1518+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1519+ descriptor = populateDescriptor (loc, mod, boxTy, inputType, rewriter,
1520+ rank, eleSize, cfiTy, typeDesc,
1521+ allocatorIdx, extraField);
15151522
15161523 return {boxTy, descriptor, eleSize};
15171524 }
@@ -1554,11 +1561,17 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15541561 mlir::Value extraField =
15551562 this ->getExtraFromBox (loc, inputBoxTyPair, loweredBox, rewriter);
15561563
1557- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1558- mlir::Value descriptor =
1559- populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1560- rank, eleSize, cfiTy, typeDesc,
1561- /* allocatorIdx=*/ kDefaultAllocator , extraField);
1564+ mlir::Value descriptor;
1565+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1566+ descriptor =
1567+ populateDescriptor (loc, gpuMod, boxTy, box.getBox ().getType (),
1568+ rewriter, rank, eleSize, cfiTy, typeDesc,
1569+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
1570+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1571+ descriptor =
1572+ populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1573+ rank, eleSize, cfiTy, typeDesc,
1574+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
15621575
15631576 return {boxTy, descriptor, eleSize};
15641577 }
0 commit comments