@@ -4961,12 +4961,18 @@ static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
49614961template <typename OpTy>
49624962static uint64_t getReductionDataSize (OpTy &op) {
49634963 if (op.getNumReductionVars () > 0 ) {
4964- assert (op.getNumReductionVars () == 1 &&
4965- " Only 1 reduction variable currently supported" );
4966- mlir::Type reductionVarTy = op.getReductionVars ()[0 ].getType ();
4964+ SmallVector<omp::DeclareReductionOp> reductions;
4965+ collectReductionDecls (op, reductions);
4966+
4967+ llvm::SmallVector<mlir::Type> members;
4968+ for (omp::DeclareReductionOp &red : reductions) {
4969+ members.push_back (red.getType ());
4970+ }
49674971 Operation *opp = op.getOperation ();
4972+ auto structType = mlir::LLVM::LLVMStructType::getLiteral (
4973+ opp->getContext (), members, /* isPacked=*/ false );
49684974 DataLayout dl = DataLayout (opp->getParentOfType <ModuleOp>());
4969- return getTypeByteSize (reductionVarTy , dl);
4975+ return getTypeByteSize (structType , dl);
49704976 }
49714977 return 0 ;
49724978}
@@ -5062,8 +5068,6 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
50625068 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
50635069 combinedMaxThreadsVal = maxThreadsVal;
50645070
5065- // Calculate reduction data size, limited to single reduction variable for
5066- // now.
50675071 int32_t reductionDataSize = 0 ;
50685072 if (isGPU && capturedOp) {
50695073 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
0 commit comments