@@ -527,16 +527,142 @@ void OpenACCRecipeBuilderBase::createFirstprivateRecipeCopy(
527527// doesn't restore it afterwards.
528528void OpenACCRecipeBuilderBase::createReductionRecipeCombiner (
529529 mlir::Location loc, mlir::Location locEnd, mlir::Value mainOp,
530- mlir::acc::ReductionRecipeOp recipe, size_t numBounds) {
530+ mlir::acc::ReductionRecipeOp recipe, size_t numBounds, QualType origType,
531+ llvm::ArrayRef<OpenACCReductionRecipe::CombinerRecipe> combinerRecipes) {
531532 mlir::Block *block =
532533 createRecipeBlock (recipe.getCombinerRegion (), mainOp.getType (), loc,
533534 numBounds, /* isInit=*/ false );
534535 builder.setInsertionPointToEnd (&recipe.getCombinerRegion ().back ());
535536 CIRGenFunction::LexicalScope ls (cgf, loc, block);
536537
537- mlir::BlockArgument lhsArg = block->getArgument (0 );
538+ mlir::Value lhsArg = block->getArgument (0 );
539+ mlir::Value rhsArg = block->getArgument (1 );
540+ llvm::MutableArrayRef<mlir::BlockArgument> boundsRange =
541+ block->getArguments ().drop_front (2 );
542+
543+ if (llvm::any_of (combinerRecipes, [](auto &r) { return r.Op == nullptr ; })) {
544+ cgf.cgm .errorNYI (loc, " OpenACC Reduction combiner not generated" );
545+ mlir::acc::YieldOp::create (builder, locEnd, block->getArgument (0 ));
546+ return ;
547+ }
548+
549+ // apply the bounds so that we can get our bounds emitted correctly.
550+ for (mlir::BlockArgument boundArg : llvm::reverse (boundsRange))
551+ std::tie (lhsArg, rhsArg) =
552+ createBoundsLoop (lhsArg, rhsArg, boundArg, loc, /* inverse=*/ false );
553+
554+ // Emitter for when we know this isn't a struct or array we have to loop
555+ // through. This should work for the 'field' once the get-element call has
556+ // been made.
557+ auto emitSingleCombiner =
558+ [&](mlir::Value lhsArg, mlir::Value rhsArg,
559+ const OpenACCReductionRecipe::CombinerRecipe &combiner) {
560+ mlir::Type elementTy =
561+ mlir::cast<cir::PointerType>(lhsArg.getType ()).getPointee ();
562+ CIRGenFunction::DeclMapRevertingRAII declMapRAIILhs{cgf, combiner.LHS };
563+ cgf.setAddrOfLocalVar (
564+ combiner.LHS , Address{lhsArg, elementTy,
565+ cgf.getContext ().getDeclAlign (combiner.LHS )});
566+ CIRGenFunction::DeclMapRevertingRAII declMapRAIIRhs{cgf, combiner.RHS };
567+ cgf.setAddrOfLocalVar (
568+ combiner.RHS , Address{rhsArg, elementTy,
569+ cgf.getContext ().getDeclAlign (combiner.RHS )});
570+
571+ [[maybe_unused]] mlir::LogicalResult stmtRes =
572+ cgf.emitStmt (combiner.Op , /* useCurrentScope=*/ true );
573+ };
574+
575+ // Emitter for when we know this is either a non-array or element of an array
576+ // (which also shouldn't be an array type?). This function should generate the
577+ // initialization code for an entire 'array-element'/non-array, including
578+ // diving into each element of a struct (if necessary).
579+ auto emitCombiner = [&](mlir::Value lhsArg, mlir::Value rhsArg, QualType ty) {
580+ assert (!ty->isArrayType () && " Array type shouldn't get here" );
581+ if (const auto *rd = ty->getAsRecordDecl ()) {
582+ if (combinerRecipes.size () == 1 &&
583+ cgf.getContext ().hasSameType (ty, combinerRecipes[0 ].LHS ->getType ())) {
584+ // If this is a 'top level' operator on the type we can just emit this
585+ // as a simple one.
586+ emitSingleCombiner (lhsArg, rhsArg, combinerRecipes[0 ]);
587+ } else {
588+ // else we have to handle each individual field after after a
589+ // get-element.
590+ for (const auto &[field, combiner] :
591+ llvm::zip_equal (rd->fields (), combinerRecipes)) {
592+ mlir::Type fieldType = cgf.convertType (field->getType ());
593+ auto fieldPtr = cir::PointerType::get (fieldType);
594+
595+ mlir::Value lhsField = builder.createGetMember (
596+ loc, fieldPtr, lhsArg, field->getName (), field->getFieldIndex ());
597+ mlir::Value rhsField = builder.createGetMember (
598+ loc, fieldPtr, rhsArg, field->getName (), field->getFieldIndex ());
599+
600+ emitSingleCombiner (lhsField, rhsField, combiner);
601+ }
602+ }
603+
604+ } else {
605+ // if this is a single-thing (because we should know this isn't an array,
606+ // as Sema wouldn't let us get here), we can just do a normal emit call.
607+ emitSingleCombiner (lhsArg, rhsArg, combinerRecipes[0 ]);
608+ }
609+ };
610+
611+ if (const auto *cat = cgf.getContext ().getAsConstantArrayType (origType)) {
612+ // If we're in an array, we have to emit the combiner for each element of
613+ // the array.
614+ auto itrTy = mlir::cast<cir::IntType>(cgf.PtrDiffTy );
615+ auto itrPtrTy = cir::PointerType::get (itrTy);
616+
617+ mlir::Value zero =
618+ builder.getConstInt (loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy ), 0 );
619+ mlir::Value itr =
620+ cir::AllocaOp::create (builder, loc, itrPtrTy, itrTy, " itr" ,
621+ cgf.cgm .getSize (cgf.getPointerAlign ()));
622+ builder.CIRBaseBuilderTy ::createStore (loc, zero, itr);
623+
624+ builder.setInsertionPointAfter (builder.createFor (
625+ loc,
626+ /* condBuilder=*/
627+ [&](mlir::OpBuilder &b, mlir::Location loc) {
628+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
629+ mlir::Value arraySize = builder.getConstInt (
630+ loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy ), cat->getZExtSize ());
631+ auto cmp = builder.createCompare (loc, cir::CmpOpKind::lt, loadItr,
632+ arraySize);
633+ builder.createCondition (cmp);
634+ },
635+ /* bodyBuilder=*/
636+ [&](mlir::OpBuilder &b, mlir::Location loc) {
637+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
638+ auto lhsElt = builder.getArrayElement (
639+ loc, loc, lhsArg, cgf.convertType (cat->getElementType ()), loadItr,
640+ /* shouldDecay=*/ true );
641+ auto rhsElt = builder.getArrayElement (
642+ loc, loc, rhsArg, cgf.convertType (cat->getElementType ()), loadItr,
643+ /* shouldDecay=*/ true );
644+
645+ emitCombiner (lhsElt, rhsElt, cat->getElementType ());
646+ builder.createYield (loc);
647+ },
648+ /* stepBuilder=*/
649+ [&](mlir::OpBuilder &b, mlir::Location loc) {
650+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
651+ auto inc = cir::UnaryOp::create (builder, loc, loadItr.getType (),
652+ cir::UnaryOpKind::Inc, loadItr);
653+ builder.CIRBaseBuilderTy ::createStore (loc, inc, itr);
654+ builder.createYield (loc);
655+ }));
538656
539- mlir::acc::YieldOp::create (builder, locEnd, lhsArg);
657+ } else if (origType->isArrayType ()) {
658+ cgf.cgm .errorNYI (loc,
659+ " OpenACC Reduction combiner non-constant array recipe" );
660+ } else {
661+ emitCombiner (lhsArg, rhsArg, origType);
662+ }
663+
664+ builder.setInsertionPointToEnd (&recipe.getCombinerRegion ().back ());
665+ mlir::acc::YieldOp::create (builder, locEnd, block->getArgument (0 ));
540666}
541667
542668} // namespace clang::CIRGen
0 commit comments