@@ -527,16 +527,140 @@ void OpenACCRecipeBuilderBase::createFirstprivateRecipeCopy(
527527// doesn't restore it afterwards.
528528void OpenACCRecipeBuilderBase::createReductionRecipeCombiner (
529529 mlir::Location loc, mlir::Location locEnd, mlir::Value mainOp,
530- mlir::acc::ReductionRecipeOp recipe, size_t numBounds) {
530+ mlir::acc::ReductionRecipeOp recipe, size_t numBounds, QualType origType,
531+ llvm::ArrayRef<OpenACCReductionRecipe::CombinerRecipe> combinerRecipes) {
531532 mlir::Block *block =
532533 createRecipeBlock (recipe.getCombinerRegion (), mainOp.getType (), loc,
533534 numBounds, /* isInit=*/ false );
534535 builder.setInsertionPointToEnd (&recipe.getCombinerRegion ().back ());
535536 CIRGenFunction::LexicalScope ls (cgf, loc, block);
536537
537- mlir::BlockArgument lhsArg = block->getArgument (0 );
538+ mlir::Value lhsArg = block->getArgument (0 );
539+ mlir::Value rhsArg = block->getArgument (1 );
540+ llvm::MutableArrayRef<mlir::BlockArgument> boundsRange =
541+ block->getArguments ().drop_front (2 );
542+
543+ if (llvm::any_of (combinerRecipes, [](auto &r) { return r.Op == nullptr ; })) {
544+ cgf.cgm .errorNYI (loc, " OpenACC Reduction combiner not generated" );
545+ mlir::acc::YieldOp::create (builder, locEnd, block->getArgument (0 ));
546+ return ;
547+ }
548+
549+ // apply the bounds so that we can get our bounds emitted correctly.
550+ for (mlir::BlockArgument boundArg : llvm::reverse (boundsRange))
551+ std::tie (lhsArg, rhsArg) =
552+ createBoundsLoop (lhsArg, rhsArg, boundArg, loc, /* inverse=*/ false );
553+
554+ // Emitter for when we know this isn't a struct or array we have to loop
555+ // through. This should work for the 'field' once the get-element call has
556+ // been made.
557+ auto emitSingleCombiner =
558+ [&](mlir::Value lhsArg, mlir::Value rhsArg,
559+ const OpenACCReductionRecipe::CombinerRecipe &combiner) {
560+ mlir::Type elementTy =
561+ mlir::cast<cir::PointerType>(lhsArg.getType ()).getPointee ();
562+ CIRGenFunction::DeclMapRevertingRAII declMapRAIILhs{cgf, combiner.LHS };
563+ cgf.setAddrOfLocalVar (
564+ combiner.LHS , Address{lhsArg, elementTy,
565+ cgf.getContext ().getDeclAlign (combiner.LHS )});
566+ CIRGenFunction::DeclMapRevertingRAII declMapRAIIRhs{cgf, combiner.RHS };
567+ cgf.setAddrOfLocalVar (
568+ combiner.RHS , Address{rhsArg, elementTy,
569+ cgf.getContext ().getDeclAlign (combiner.RHS )});
570+
571+ [[maybe_unused]] mlir::LogicalResult stmtRes =
572+ cgf.emitStmt (combiner.Op , /* useCurrentScope=*/ true );
573+ };
574+
575+ // Emitter for when we know this is either a non-array or element of an array
576+ // (which also shouldn't be an array type?). This function should generate the
577+ // loop to do this on each individual array or struct element (if necessary).
578+ auto emitCombiner = [&](mlir::Value lhsArg, mlir::Value rhsArg, QualType Ty) {
579+ if (const auto *RD = Ty->getAsRecordDecl ()) {
580+ if (combinerRecipes.size () == 1 &&
581+ cgf.getContext ().hasSameType (Ty, combinerRecipes[0 ].LHS ->getType ())) {
582+ // If this is a 'top level' operator on the type we can just emit this
583+ // as a simple one.
584+ emitSingleCombiner (lhsArg, rhsArg, combinerRecipes[0 ]);
585+ } else {
586+ // else we have to handle each individual field after after a
587+ // get-element.
588+ for (const auto &[field, combiner] :
589+ llvm::zip_equal (RD->fields (), combinerRecipes)) {
590+ mlir::Type fieldType = cgf.convertType (field->getType ());
591+ auto fieldPtr = cir::PointerType::get (fieldType);
592+
593+ mlir::Value lhsField = builder.createGetMember (
594+ loc, fieldPtr, lhsArg, field->getName (), field->getFieldIndex ());
595+ mlir::Value rhsField = builder.createGetMember (
596+ loc, fieldPtr, rhsArg, field->getName (), field->getFieldIndex ());
597+
598+ emitSingleCombiner (lhsField, rhsField, combiner);
599+ }
600+ }
601+
602+ } else {
603+ // if this is a single-thing (because we should know this isn't an array,
604+ // as Sema wouldn't let us get here), we can just do a normal emit call.
605+ emitSingleCombiner (lhsArg, rhsArg, combinerRecipes[0 ]);
606+ }
607+ };
608+
609+ if (const auto *cat = cgf.getContext ().getAsConstantArrayType (origType)) {
610+ // If we're in an array, we have to emit the combiner for each element of
611+ // the array.
612+ auto itrTy = mlir::cast<cir::IntType>(cgf.PtrDiffTy );
613+ auto itrPtrTy = cir::PointerType::get (itrTy);
614+
615+ mlir::Value zero =
616+ builder.getConstInt (loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy ), 0 );
617+ mlir::Value itr =
618+ cir::AllocaOp::create (builder, loc, itrPtrTy, itrTy, " itr" ,
619+ cgf.cgm .getSize (cgf.getPointerAlign ()));
620+ builder.CIRBaseBuilderTy ::createStore (loc, zero, itr);
621+
622+ builder.setInsertionPointAfter (builder.createFor (
623+ loc,
624+ /* condBuilder=*/
625+ [&](mlir::OpBuilder &b, mlir::Location loc) {
626+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
627+ mlir::Value arraySize = builder.getConstInt (
628+ loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy ), cat->getZExtSize ());
629+ auto cmp = builder.createCompare (loc, cir::CmpOpKind::lt, loadItr,
630+ arraySize);
631+ builder.createCondition (cmp);
632+ },
633+ /* bodyBuilder=*/
634+ [&](mlir::OpBuilder &b, mlir::Location loc) {
635+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
636+ auto lhsElt = builder.getArrayElement (
637+ loc, loc, lhsArg, cgf.convertType (cat->getElementType ()), loadItr,
638+ /* shouldDecay=*/ true );
639+ auto rhsElt = builder.getArrayElement (
640+ loc, loc, rhsArg, cgf.convertType (cat->getElementType ()), loadItr,
641+ /* shouldDecay=*/ true );
642+
643+ emitCombiner (lhsElt, rhsElt, cat->getElementType ());
644+ builder.createYield (loc);
645+ },
646+ /* stepBuilder=*/
647+ [&](mlir::OpBuilder &b, mlir::Location loc) {
648+ auto loadItr = cir::LoadOp::create (builder, loc, {itr});
649+ auto inc = cir::UnaryOp::create (builder, loc, loadItr.getType (),
650+ cir::UnaryOpKind::Inc, loadItr);
651+ builder.CIRBaseBuilderTy ::createStore (loc, inc, itr);
652+ builder.createYield (loc);
653+ }));
538654
539- mlir::acc::YieldOp::create (builder, locEnd, lhsArg);
655+ } else if (origType->isArrayType ()) {
656+ cgf.cgm .errorNYI (loc,
657+ " OpenACC Reduction combiner non-constant array recipe" );
658+ } else {
659+ emitCombiner (lhsArg, rhsArg, origType);
660+ }
661+
662+ builder.setInsertionPointToEnd (&recipe.getCombinerRegion ().back ());
663+ mlir::acc::YieldOp::create (builder, locEnd, block->getArgument (0 ));
540664}
541665
542666} // namespace clang::CIRGen
0 commit comments