1313#include " ReductionProcessor.h"
1414
1515#include " flang/Lower/AbstractConverter.h"
16+ #include " flang/Optimizer/Builder/HLFIRTools.h"
1617#include " flang/Optimizer/Builder/Todo.h"
1718#include " flang/Optimizer/Dialect/FIRType.h"
1819#include " flang/Optimizer/HLFIR/HLFIROps.h"
@@ -90,10 +91,42 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
9091 if (isByRef)
9192 byrefAddition = " _byref" ;
9293
93- return (llvm::Twine (name) +
94- (ty.isIntOrIndex () ? llvm::Twine (" _i_" ) : llvm::Twine (" _f_" )) +
95- llvm::Twine (ty.getIntOrFloatBitWidth ()) + byrefAddition)
96- .str ();
94+ if (fir::isa_trivial (ty))
95+ return (llvm::Twine (name) +
96+ (ty.isIntOrIndex () ? llvm::Twine (" _i_" ) : llvm::Twine (" _f_" )) +
97+ llvm::Twine (ty.getIntOrFloatBitWidth ()) + byrefAddition)
98+ .str ();
99+
100+ // creates a name like reduction_i_64_box_ux4x3
101+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
102+ // TODO: support for allocatable boxes:
103+ // !fir.box<!fir.heap<!fir.array<...>>>
104+ fir::SequenceType seqTy = fir::unwrapRefType (boxTy.getEleTy ())
105+ .dyn_cast_or_null <fir::SequenceType>();
106+ if (!seqTy)
107+ return {};
108+
109+ std::string prefix = getReductionName (
110+ name, fir::unwrapSeqOrBoxedSeqType (ty), /* isByRef=*/ false );
111+ if (prefix.empty ())
112+ return {};
113+ std::stringstream tyStr;
114+ tyStr << prefix << " _box_" ;
115+ bool first = true ;
116+ for (std::int64_t extent : seqTy.getShape ()) {
117+ if (first)
118+ first = false ;
119+ else
120+ tyStr << " x" ;
121+ if (extent == seqTy.getUnknownExtent ())
122+ tyStr << ' u' ; // I'm not sure that '?' is safe in symbol names
123+ else
124+ tyStr << extent;
125+ }
126+ return (tyStr.str () + byrefAddition).str ();
127+ }
128+
129+ return {};
97130}
98131
99132std::string ReductionProcessor::getReductionName (
@@ -281,13 +314,158 @@ mlir::Value ReductionProcessor::createScalarCombiner(
281314 return reductionOp;
282315}
283316
317+ // / Create reduction combiner region for reduction variables which are boxed
318+ // / arrays
// Emits, at the builder's current insertion point, the element-wise combine
// of two boxed arrays followed by the terminating omp.yield.
//   builder - builder used to create every operation below.
//   loc     - location attached to the generated operations.
//   redId   - reduction kind, forwarded to createScalarCombiner.
//   boxTy   - box type of the reduction variable; its element type must be a
//             fir::SequenceType for which hasUnknownShape() is false,
//             otherwise a TODO diagnostic is raised.
//   lhs/rhs - values that are loaded with fir.load to obtain the boxes (the
//             comment below describes them as fir.ref<fir.box<...>>). The
//             combined result is stored element by element through lhs, and
//             the original (unloaded) lhs value is what gets yielded.
// NOTE(review): lower bounds and extents are read from the lhs box only and
// the resulting shape_shift is reused to address rhs; this presumably relies
// on both boxes describing arrays with identical bounds - confirm.
319+ static void genBoxCombiner (fir::FirOpBuilder &builder, mlir::Location loc,
320+ ReductionProcessor::ReductionIdentifier redId,
321+ fir::BaseBoxType boxTy, mlir::Value lhs,
322+ mlir::Value rhs) {
323+ fir::SequenceType seqTy =
324+ mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy ());
325+ // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
326+ if (!seqTy || seqTy.hasUnknownShape ())
327+ TODO (loc, " Unsupported boxed type in OpenMP reduction" );
328+
329+ // load fir.ref<fir.box<...>>
// Keep the incoming lhs value before it is overwritten by the loaded box: it
// is the value yielded at the end of the region.
330+ mlir::Value lhsAddr = lhs;
331+ lhs = builder.create <fir::LoadOp>(loc, lhs);
332+ rhs = builder.create <fir::LoadOp>(loc, rhs);
333+
// extents drives the loop nest; lbAndExtents holds one (lower bound, extent)
// pair per dimension and feeds the shape_shift operation.
334+ const unsigned rank = seqTy.getDimension ();
335+ llvm::SmallVector<mlir::Value> extents;
336+ extents.reserve (rank);
337+ llvm::SmallVector<mlir::Value> lbAndExtents;
338+ lbAndExtents.reserve (rank * 2 );
339+
340+ // Get box lowerbounds and extents:
341+ mlir::Type idxTy = builder.getIndexType ();
342+ for (unsigned i = 0 ; i < rank; ++i) {
343+ // TODO: ideally we want to hoist box reads out of the critical section.
344+ // We could do this by having box dimensions in block arguments like
345+ // OpenACC does
346+ mlir::Value dim = builder.createIntegerConstant (loc, idxTy, i);
347+ auto dimInfo =
348+ builder.create <fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
349+ extents.push_back (dimInfo.getExtent ());
350+ lbAndExtents.push_back (dimInfo.getLowerBound ());
351+ lbAndExtents.push_back (dimInfo.getExtent ());
352+ }
353+
354+ auto shapeShiftTy = fir::ShapeShiftType::get (builder.getContext (), rank);
355+ auto shapeShift =
356+ builder.create <fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
357+
358+ // Iterate over array elements, applying the equivalent scalar reduction:
359+
360+ // A hlfir::elemental here gets inlined with a temporary so create the
361+ // loop nest directly.
362+ // This function already controls all of the code in this region so we
363+ // know this won't miss any opportunities for clever elemental inlining
364+ hlfir::LoopNest nest =
365+ hlfir::genLoopNest (loc, builder, extents, /* isUnordered=*/ true );
// Everything up to the setInsertionPointAfter call below is emitted inside
// the innermost loop body.
366+ builder.setInsertionPointToStart (nest.innerLoop .getBody ());
367+ mlir::Type refTy = fir::ReferenceType::get (seqTy.getEleTy ());
368+ auto lhsEleAddr = builder.create <fir::ArrayCoorOp>(
369+ loc, refTy, lhs, shapeShift, /* slice=*/ mlir::Value{},
370+ nest.oneBasedIndices , /* typeparms=*/ mlir::ValueRange{});
371+ auto rhsEleAddr = builder.create <fir::ArrayCoorOp>(
372+ loc, refTy, rhs, shapeShift, /* slice=*/ mlir::Value{},
373+ nest.oneBasedIndices , /* typeparms=*/ mlir::ValueRange{});
374+ auto lhsEle = builder.create <fir::LoadOp>(loc, lhsEleAddr);
375+ auto rhsEle = builder.create <fir::LoadOp>(loc, rhsEleAddr);
// NOTE(review): the type argument here is the element reference type (refTy),
// not the bare element type; presumably createScalarCombiner unwraps
// references itself - confirm against its definition.
376+ mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner (
377+ builder, loc, redId, refTy, lhsEle, rhsEle);
// Write the combined element back through the lhs element address.
378+ builder.create <fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
379+
// Leave the loop nest before emitting the terminator; the yielded value is
// the lhs value as it was passed in, before the box was loaded from it.
380+ builder.setInsertionPointAfter (nest.outerLoop );
381+ builder.create <mlir::omp::YieldOp>(loc, lhsAddr);
382+ }
383+
384+ // generate combiner region for reduction operations
// Dispatches on the reduction variable type after stripping any reference
// wrapper, and always finishes by emitting the omp.yield terminator (directly
// for trivial types, via genBoxCombiner for boxes).
//   builder - builder used to create the operations.
//   loc     - location attached to the generated operations.
//   redId   - reduction kind, forwarded to the scalar or box combiner.
//   ty      - type of the reduction variable; a reference wrapper is removed.
//   lhs/rhs - the two values to combine.
//   isByRef - for trivial types: when true the result is stored into lhs and
//             lhs is yielded, otherwise the combined value itself is yielded.
//             Not consulted on the box path.
// Any type that is neither trivial nor a fir::BaseBoxType ends in a TODO
// diagnostic.
385+ static void genCombiner (fir::FirOpBuilder &builder, mlir::Location loc,
386+ ReductionProcessor::ReductionIdentifier redId,
387+ mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
388+ bool isByRef) {
389+ ty = fir::unwrapRefType (ty);
390+
391+ if (fir::isa_trivial (ty)) {
// loadIfRef is applied to both operands, so the same code serves the by-value
// and by-reference cases.
392+ mlir::Value lhsLoaded = builder.loadIfRef (loc, lhs);
393+ mlir::Value rhsLoaded = builder.loadIfRef (loc, rhs);
394+
395+ mlir::Value result = ReductionProcessor::createScalarCombiner (
396+ builder, loc, redId, ty, lhsLoaded, rhsLoaded);
397+ if (isByRef) {
// By reference: update lhs in place and yield the original lhs value.
398+ builder.create <fir::StoreOp>(loc, result, lhs);
399+ builder.create <mlir::omp::YieldOp>(loc, lhs);
400+ } else {
// By value: yield the freshly combined scalar.
401+ builder.create <mlir::omp::YieldOp>(loc, result);
402+ }
403+ return ;
404+ }
405+ // all arrays should have been boxed
406+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
// genBoxCombiner emits its own omp.yield, so nothing more is needed here.
407+ genBoxCombiner (builder, loc, redId, boxTy, lhs, rhs);
408+ return ;
409+ }
410+
411+ TODO (loc, " OpenMP genCombiner for unsupported reduction variable type" );
412+ }
413+
// Emit the body of a reduction initializer at the builder's current insertion
// point and return the value holding the initialised private copy. No
// terminator is created here; the caller in createReductionDecl yields the
// returned value.
//   builder - builder used to create the operations.
//   loc     - location attached to the generated operations.
//   redId   - reduction kind, used to obtain the initial value.
//   type    - type of the reduction variable; a reference wrapper is removed
//             before dispatching.
//   isByRef - trivial types: when true the init value is stored into a new
//             alloca which is returned, otherwise the init value itself is
//             returned. Boxed arrays: asserted to be true.
// Types that are neither trivial nor a fir::BaseBoxType end in a TODO
// diagnostic.
414+ static mlir::Value
415+ createReductionInitRegion (fir::FirOpBuilder &builder, mlir::Location loc,
416+ const ReductionProcessor::ReductionIdentifier redId,
417+ mlir::Type type, bool isByRef) {
418+ mlir::Type ty = fir::unwrapRefType (type);
// The init value is requested for the type with sequence/box wrappers
// removed. This happens before the type dispatch below, so it also runs for
// types that end up in the final TODO.
419+ mlir::Value initValue = ReductionProcessor::getReductionInitValue (
420+ loc, fir::unwrapSeqOrBoxedSeqType (ty), redId, builder);
421+
422+ if (fir::isa_trivial (ty)) {
423+ if (isByRef) {
// By reference: materialise the init value in a fresh alloca and hand back
// the alloca.
424+ mlir::Value alloca = builder.create <fir::AllocaOp>(loc, ty);
425+ builder.createStoreWithConvert (loc, initValue, alloca);
426+ return alloca;
427+ }
428+ // by val
429+ return initValue;
430+ }
431+
432+ // all arrays are boxed
433+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
434+ assert (isByRef && " passing arrays by value is unsupported" );
435+ // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
436+ mlir::Type innerTy = fir::extractSequenceType (boxTy);
437+ if (!mlir::isa<fir::SequenceType>(innerTy))
438+ TODO (loc, " Unsupported boxed type for reduction" );
439+ // Create the private copy from the initial fir.box:
// The mold is argument 0 of the block the builder is currently inserting
// into. NOTE(review): presumably that is the initializer region's single
// block argument (the original variable) - confirm against the caller.
440+ hlfir::Entity source = hlfir::Entity{builder.getBlock ()->getArgument (0 )};
441+
442+ // TODO: if the whole reduction is nested inside of a loop, this alloca
443+ // could lead to a stack overflow (the memory is only freed at the end of
444+ // the stack frame). The reduction declare operation needs a deallocation
445+ // region to undo the init region.
446+ hlfir::Entity temp = createStackTempFromMold (loc, builder, source);
447+
448+ // Put the temporary inside of a box:
449+ hlfir::Entity box = hlfir::genVariableBox (loc, builder, temp);
// Assign the scalar init value to the boxed temporary with hlfir.assign.
450+ builder.create <hlfir::AssignOp>(loc, initValue, box);
// The box is stored into its own alloca and that alloca is what is returned.
451+ mlir::Value boxAlloca = builder.create <fir::AllocaOp>(loc, ty);
452+ builder.create <fir::StoreOp>(loc, box, boxAlloca);
453+ return boxAlloca;
454+ }
455+
456+ TODO (loc, " createReductionInitRegion for unsupported type" );
457+ }
458+
284459mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl (
285460 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
286461 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
287462 bool isByRef) {
288463 mlir::OpBuilder::InsertionGuard guard (builder);
289464 mlir::ModuleOp module = builder.getModule ();
290465
466+ if (reductionOpName.empty ())
467+ TODO (loc, " Reduction of some types is not supported" );
468+
291469 auto decl =
292470 module .lookupSymbol <mlir::omp::ReductionDeclareOp>(reductionOpName);
293471 if (decl)
@@ -304,14 +482,9 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
304482 decl.getInitializerRegion ().end (), {type}, {loc});
305483 builder.setInsertionPointToEnd (&decl.getInitializerRegion ().back ());
306484
307- mlir::Value init = getReductionInitValue (loc, type, redId, builder);
308- if (isByRef) {
309- mlir::Value alloca = builder.create <fir::AllocaOp>(loc, valTy);
310- builder.createStoreWithConvert (loc, init, alloca);
311- builder.create <mlir::omp::YieldOp>(loc, alloca);
312- } else {
313- builder.create <mlir::omp::YieldOp>(loc, init);
314- }
485+ mlir::Value init =
486+ createReductionInitRegion (builder, loc, redId, type, isByRef);
487+ builder.create <mlir::omp::YieldOp>(loc, init);
315488
316489 builder.createBlock (&decl.getReductionRegion (),
317490 decl.getReductionRegion ().end (), {type, type},
@@ -320,19 +493,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
320493 builder.setInsertionPointToEnd (&decl.getReductionRegion ().back ());
321494 mlir::Value op1 = decl.getReductionRegion ().front ().getArgument (0 );
322495 mlir::Value op2 = decl.getReductionRegion ().front ().getArgument (1 );
323- mlir::Value outAddr = op1;
324-
325- op1 = builder.loadIfRef (loc, op1);
326- op2 = builder.loadIfRef (loc, op2);
327-
328- mlir::Value reductionOp =
329- createScalarCombiner (builder, loc, redId, type, op1, op2);
330- if (isByRef) {
331- builder.create <fir::StoreOp>(loc, reductionOp, outAddr);
332- builder.create <mlir::omp::YieldOp>(loc, outAddr);
333- } else {
334- builder.create <mlir::omp::YieldOp>(loc, reductionOp);
335- }
496+ genCombiner (builder, loc, redId, type, op1, op2, isByRef);
336497
337498 return decl;
338499}
@@ -387,13 +548,33 @@ void ReductionProcessor::addReductionDecl(
387548
388549 // initial pass to collect all reduction vars so we can figure out if this
389550 // should happen byref
551+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
390552 for (const Object &object : objectList) {
391553 const Fortran::semantics::Symbol *symbol = object.id ();
392554 if (reductionSymbols)
393555 reductionSymbols->push_back (symbol);
394556 mlir::Value symVal = converter.getSymbolAddress (*symbol);
395- if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
557+ auto redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
558+
559+ // all arrays must be boxed so that we have convenient access to all the
560+ // information needed to iterate over the array
561+ if (mlir::isa<fir::SequenceType>(redType.getEleTy ())) {
562+ hlfir::Entity entity{symVal};
563+ entity = genVariableBox (currentLocation, builder, entity);
564+ mlir::Value box = entity.getBase ();
565+
566+ // Always pass the box by reference so that the OpenMP dialect
567+ // verifiers don't need to know anything about fir.box
568+ auto alloca =
569+ builder.create <fir::AllocaOp>(currentLocation, box.getType ());
570+ builder.create <fir::StoreOp>(currentLocation, box, alloca);
571+
572+ symVal = alloca;
573+ redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
574+ } else if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>()) {
396575 symVal = declOp.getBase ();
576+ }
577+
397578 reductionVars.push_back (symVal);
398579 }
399580 const bool isByRef = doReductionByRef (reductionVars);
@@ -418,24 +599,17 @@ void ReductionProcessor::addReductionDecl(
418599 break ;
419600 }
420601
421- for (const Object &object : objectList) {
422- const Fortran::semantics::Symbol *symbol = object.id ();
423- mlir::Value symVal = converter.getSymbolAddress (*symbol);
424- if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
425- symVal = declOp.getBase ();
426- auto redType = symVal.getType ().cast <fir::ReferenceType>();
602+ for (mlir::Value symVal : reductionVars) {
603+ auto redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
427604 if (redType.getEleTy ().isa <fir::LogicalType>())
428605 decl = createReductionDecl (
429606 firOpBuilder,
430607 getReductionName (intrinsicOp, firOpBuilder.getI1Type (), isByRef),
431608 redId, redType, currentLocation, isByRef);
432- else if (redType. getEleTy (). isIntOrIndexOrFloat ()) {
609+ else
433610 decl = createReductionDecl (
434611 firOpBuilder, getReductionName (intrinsicOp, redType, isByRef),
435612 redId, redType, currentLocation, isByRef);
436- } else {
437- TODO (currentLocation, " Reduction of some types is not supported" );
438- }
439613 reductionDeclSymbols.push_back (mlir::SymbolRefAttr::get (
440614 firOpBuilder.getContext (), decl.getSymName ()));
441615 }
@@ -452,8 +626,8 @@ void ReductionProcessor::addReductionDecl(
452626 if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
453627 symVal = declOp.getBase ();
454628 auto redType = symVal.getType ().cast <fir::ReferenceType>();
455- assert ( redType.getEleTy ().isIntOrIndexOrFloat () &&
456- " Unsupported reduction type" );
629+ if (! redType.getEleTy ().isIntOrIndexOrFloat ())
630+ TODO (currentLocation, " User Defined Reduction on non-trivial type" );
457631 decl = createReductionDecl (
458632 firOpBuilder,
459633 getReductionName (getRealName (*reductionIntrinsic).ToString (),
0 commit comments