@@ -4961,12 +4961,18 @@ static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
4961
4961
template <typename OpTy>
4962
4962
static uint64_t getReductionDataSize (OpTy &op) {
4963
4963
if (op.getNumReductionVars () > 0 ) {
4964
- assert (op.getNumReductionVars () == 1 &&
4965
- " Only 1 reduction variable currently supported" );
4966
- mlir::Type reductionVarTy = op.getReductionVars ()[0 ].getType ();
4964
+ SmallVector<omp::DeclareReductionOp> reductions;
4965
+ collectReductionDecls (op, reductions);
4966
+
4967
+ llvm::SmallVector<mlir::Type> members;
4968
+ for (omp::DeclareReductionOp &red : reductions) {
4969
+ members.push_back (red.getType ());
4970
+ }
4967
4971
Operation *opp = op.getOperation ();
4972
+ auto structType = mlir::LLVM::LLVMStructType::getLiteral (
4973
+ opp->getContext (), members, /* isPacked=*/ false );
4968
4974
DataLayout dl = DataLayout (opp->getParentOfType <ModuleOp>());
4969
- return getTypeByteSize (reductionVarTy , dl);
4975
+ return getTypeByteSize (structType , dl);
4970
4976
}
4971
4977
return 0 ;
4972
4978
}
@@ -5062,8 +5068,6 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
5062
5068
(maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5063
5069
combinedMaxThreadsVal = maxThreadsVal;
5064
5070
5065
- // Calculate reduction data size, limited to single reduction variable for
5066
- // now.
5067
5071
int32_t reductionDataSize = 0 ;
5068
5072
if (isGPU && capturedOp) {
5069
5073
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
0 commit comments