@@ -30,6 +30,65 @@ using namespace special_ticks;
3030// / and default memory space.
3131static bool isMemRefTypeOk (MemRefType type) { return type.hasStaticShape (); }
3232
33+ static inline int64_t getSizeInBytes (MemRefType &memType) {
34+ // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
35+ // least have a large enough size for i1
36+ int64_t size = memType.getElementTypeBitWidth () / 8 ;
37+ size = (size > 0 ) ? size : 1 ;
38+ for (auto v : memType.getShape ()) {
39+ size *= v;
40+ }
41+ return size;
42+ }
43+
44+ static bool needsHoistOutOfParallelLoop (Operation *op) {
45+ Operation *parent =
46+ op->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
47+ if (isa_and_nonnull<scf::ForallOp>(parent)) {
48+ // check if the current allocation is between the nested pfor, and use
49+ // inside the inner parallel loop
50+ SmallVector<Operation *, 4 > parallelOpInCurBlock;
51+ Block *curBlock = op->getBlock ();
52+ for (auto &curOp : curBlock->getOperations ()) {
53+ if (isa<scf::ForallOp>(curOp)) {
54+ parallelOpInCurBlock.push_back (&curOp);
55+ }
56+ }
57+
58+ if (parallelOpInCurBlock.empty ())
59+ return false ;
60+
61+ for (auto *use : op->getUsers ()) {
62+ for (auto *parallelOp : parallelOpInCurBlock) {
63+ if (parallelOp->isAncestor (use)) {
64+ return true ;
65+ }
66+ }
67+ }
68+ }
69+
70+ return false ;
71+ }
72+
73+ static bool isForallLoopBoundStatic (Operation *op) {
74+ auto forallOp = dyn_cast<scf::ForallOp>(op);
75+ if (!forallOp)
76+ return false ;
77+
78+ auto lbs = forallOp.getMixedLowerBound ();
79+ auto ubs = forallOp.getMixedUpperBound ();
80+ auto steps = forallOp.getMixedStep ();
81+ auto allConstantValue = [](SmallVector<OpFoldResult> vals) -> bool {
82+ return llvm::all_of (vals, [](OpFoldResult val) {
83+ std::optional<int64_t > const_val = getConstantIntValue (val);
84+ return const_val.has_value ();
85+ });
86+ };
87+
88+ return allConstantValue (lbs) && allConstantValue (ubs) &&
89+ allConstantValue (steps);
90+ }
91+
3392void Tick::update (int64_t tick) {
3493 if (tick == UNTRACEABLE_ACCESS) {
3594 firstAccess = UNTRACEABLE_ACCESS;
@@ -180,28 +239,60 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
180239// trait, and is not scf.for
181240Operation *TickCollecter::getAllocScope (TickCollecterStates *s,
182241 Operation *op) const {
183- auto parent = op;
242+ Operation *parent = op;
243+ bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop (op);
244+
184245 for (;;) {
185246 parent = parent->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
186247 if (!parent) {
187248 return nullptr ;
188249 }
189- if (!isa<scf::ForOp>(parent)) {
190- return parent;
191- }
250+
251+ if (isa<scf::ForOp>(parent))
252+ continue ;
253+
254+ if (isa<scf::ForallOp>(parent) &&
255+ (moveToUpperParellelLoop && isForallLoopBoundStatic (parent)))
256+ continue ;
257+
258+ return parent;
192259 }
193260}
194261
195262FailureOr<size_t > TickCollecter::getAllocSize (TickCollecterStates *s,
196263 Operation *op) const {
197264 auto refType = cast<MemRefType>(op->getResultTypes ().front ());
198- int64_t size = refType.getElementTypeBitWidth () / 8 ;
199- // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
200- // least have a large enough size for i1
201- size = (size != 0 ) ? size : 1 ;
202- for (auto v : refType.getShape ()) {
203- size *= v;
265+
266+ // Get the total number of threads from the outermost to the current level of
267+ // the parallel loop that the allocation located in.
268+ int64_t numThreads = 1 ;
269+ if (needsHoistOutOfParallelLoop (op)) {
270+ Operation *parent =
271+ op->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
272+ while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
273+ if (!isForallLoopBoundStatic (forallOp))
274+ break ;
275+
276+ OpBuilder builder{forallOp->getContext ()};
277+ std::optional<int64_t > numIterations;
278+ for (auto [lb, ub, step] : llvm::zip (forallOp.getLowerBound (builder),
279+ forallOp.getUpperBound (builder),
280+ forallOp.getStep (builder))) {
281+ numIterations = constantTripCount (lb, ub, step);
282+ if (numIterations.has_value ()) {
283+ numThreads *= numIterations.value ();
284+ } else {
285+ return op->emitError (" Expecting static loop range!" );
286+ }
287+ }
288+
289+ parent = parent->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
290+ }
204291 }
292+ assert (numThreads > 0 );
293+
294+ int64_t size = getSizeInBytes (refType);
295+ size *= numThreads;
205296 if (size > 0 ) {
206297 return static_cast <size_t >(size);
207298 }
@@ -391,11 +482,113 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
391482 Value mergedAlloc,
392483 int64_t byteOffset) const {
393484 builder.setInsertionPoint (origAllocOp);
394- auto byteShift =
395- builder.create <arith::ConstantIndexOp>(origAllocOp->getLoc (), byteOffset);
396- return builder.create <memref::ViewOp>(origAllocOp->getLoc (),
397- origAllocOp->getResultTypes ().front (),
398- mergedAlloc, byteShift, ValueRange{});
485+ auto loc = origAllocOp->getLoc ();
486+ auto byteShift = builder.create <arith::ConstantIndexOp>(loc, byteOffset);
487+
488+ bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop (origAllocOp);
489+ Operation *parent =
490+ origAllocOp->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
491+ if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent))
492+ return builder.create <memref::ViewOp>(loc,
493+ origAllocOp->getResultTypes ().front (),
494+ mergedAlloc, byteShift, ValueRange{});
495+
496+ // get the aggregated inductorVar
497+ Value inductVar;
498+ bool isOuterMostLoop = true ;
499+ int64_t innerLoopUpperBound = 1 ;
500+ while (parent) {
501+ if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
502+ if (isForallLoopBoundStatic (forallOp)) {
503+ SmallVector<Value> ubs = forallOp.getUpperBound (builder);
504+ SmallVector<Value> lbs = forallOp.getLowerBound (builder);
505+ SmallVector<Value> steps = forallOp.getStep (builder);
506+ SmallVector<Value> inductionVars = forallOp.getInductionVars ();
507+
508+ auto getCurrentVar = [&loc, &builder](Value var, Value lb,
509+ Value step) -> Value {
510+ if (!isConstantIntValue (lb, 0 ))
511+ var = builder.create <arith::SubIOp>(loc, var, lb);
512+
513+ if (!isConstantIntValue (step, 1 ))
514+ var = builder.create <arith::DivSIOp>(loc, var, step);
515+ return var;
516+ };
517+
518+ auto getAggregatedVar =
519+ [&loc, &builder, &getCurrentVar](
520+ const SmallVector<Value> &_lbs, const SmallVector<Value> &_ubs,
521+ const SmallVector<Value> &_steps,
522+ const SmallVector<Value> &_inductVars) -> Value {
523+ Value var;
524+ if (_ubs.size () == 1 ) {
525+ var = getCurrentVar (_inductVars[0 ], _lbs[0 ], _steps[0 ]);
526+ return var;
527+ } else {
528+ bool isFirstLoop = true ;
529+ for (auto [lb, ub, step, inductVar] :
530+ llvm::zip (_lbs, _ubs, _steps, _inductVars)) {
531+ if (isFirstLoop) {
532+ var = getCurrentVar (inductVar, lb, step);
533+ isFirstLoop = false ;
534+ } else {
535+ Value cur_var = getCurrentVar (inductVar, lb, step);
536+ std::optional<int64_t > bound = constantTripCount (lb, ub, step);
537+ assert (bound.has_value ());
538+ Value boundVal =
539+ builder.create <arith::ConstantIndexOp>(loc, bound.value ());
540+ Value tmpVal =
541+ builder.create <arith::MulIOp>(loc, var, boundVal);
542+ var = builder.create <arith::AddIOp>(loc, tmpVal, cur_var);
543+ }
544+ }
545+ return var;
546+ }
547+ };
548+
549+ if (isOuterMostLoop) {
550+ inductVar = getAggregatedVar (lbs, ubs, steps, inductionVars);
551+ isOuterMostLoop = false ;
552+ } else {
553+ Value currentVar = getAggregatedVar (lbs, ubs, steps, inductionVars);
554+
555+ Value innerLoopBoundVal =
556+ builder.create <arith::ConstantIndexOp>(loc, innerLoopUpperBound);
557+ Value intermediateVal =
558+ builder.create <arith::MulIOp>(loc, currentVar, innerLoopBoundVal);
559+ inductVar =
560+ builder.create <arith::AddIOp>(loc, inductVar, intermediateVal);
561+ }
562+ // get aggregated loop bound
563+ for (auto [lb, ub, step] : llvm::zip (lbs, ubs, steps)) {
564+ std::optional<int64_t > cur_bound = constantTripCount (lb, ub, step);
565+ assert (cur_bound.has_value ());
566+ innerLoopUpperBound *= cur_bound.value ();
567+ }
568+ }
569+ }
570+
571+ parent = parent->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
572+ }
573+
574+ if (!isOuterMostLoop) {
575+ // get original shape size
576+ auto memType = cast<MemRefType>(origAllocOp->getResultTypes ().front ());
577+ int64_t size = getSizeInBytes (memType);
578+ Value origSize = builder.create <arith::ConstantIndexOp>(loc, size);
579+ Value offsetPerThread =
580+ builder.create <arith::MulIOp>(loc, inductVar, origSize);
581+ Value byteShiftPerThread =
582+ builder.create <arith::AddIOp>(loc, byteShift, offsetPerThread);
583+
584+ return builder.create <memref::ViewOp>(
585+ loc, origAllocOp->getResultTypes ().front (), mergedAlloc,
586+ byteShiftPerThread, ValueRange{});
587+ } else {
588+ return builder.create <memref::ViewOp>(loc,
589+ origAllocOp->getResultTypes ().front (),
590+ mergedAlloc, byteShift, ValueRange{});
591+ }
399592}
400593
401594LogicalResult
0 commit comments