@@ -406,6 +406,8 @@ struct BlockwiseGemmAccelRewritePattern
406406 int64_t kpackPerBlock = tuningParams.getKpackPerBlock ();
407407 int64_t mPerWave = tuningParams.getMPerWave ();
408408 int64_t nPerWave = tuningParams.getNPerWave ();
409+ bool loadAFromLDS = adaptor.getLoadAfromLDS ();
410+ bool loadBFromLDS = adaptor.getLoadBfromLDS ();
409411
410412 Type bufferElemTypeA =
411413 cast<MemRefType>(adaptor.getMatrixA ().getType ()).getElementType ();
@@ -445,6 +447,9 @@ struct BlockwiseGemmAccelRewritePattern
445447 << " nRepeat: " << nRepeats << " \n "
446448 << " kBasePerThread: " << kBasePerThread << " \n "
447449 << " kpackPerBlock: " << kpackPerBlock << " \n "
450+ << " loadAFromLDS: " << loadAFromLDS << " \n "
451+ << " loadBFromLDS: " << loadBFromLDS << " \n "
452+ << " rotateMWithK: " << op.getRotateMWithK () << " \n "
448453 << " bufferA type: " << adaptor.getBufferA ().getType () << " \n "
449454 << " bufferB type: " << adaptor.getBufferB ().getType () << " \n " );
450455
@@ -463,44 +468,75 @@ struct BlockwiseGemmAccelRewritePattern
463468 // considered a temporary hack until we have a proper way of "searching"
464469 // through different schedules (either heuristically or automatically)
465470
466- Value wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad (
467- b, loc, op.getMatrixA (), op.getBlockSize (), op.getInMPerThread (), " m" ,
468- op.getRotateMWithK ());
469- Value wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad (
470- b, loc, op.getMatrixB (), op.getBlockSize (), op.getInNPerThread (), " n" ,
471- op.getRotateNWithK ());
471+ Value wrappedLDSBufferForLoadA, wrappedLDSBufferForLoadB;
472+ if (loadAFromLDS) {
473+ wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad (
474+ b, loc, op.getMatrixA (), op.getBlockSize (), op.getInMPerThread (), " m" ,
475+ op.getRotateMWithK (), op.getSplitKAcrossThreadsFirstA ());
476+ }
477+ if (loadBFromLDS) {
478+ wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad (
479+ b, loc, op.getMatrixB (), op.getBlockSize (), op.getInNPerThread (), " n" ,
480+ op.getRotateNWithK (), op.getSplitKAcrossThreadsFirstA ());
481+ }
472482
473483 auto mLoop = b.create <affine::AffineForOp>(loc, 0 , mRepeats );
474484 {
475485 OpBuilder::InsertionGuard guard (b);
476486 b.setInsertionPointToStart (mLoop .getBody ());
477487 Value i = mLoop .getInductionVar ();
478488
479- // regsA = read A from LDS
480- b.create <ThreadwiseReadIntoOp>(
481- loc, wrappedLDSBufferForLoadA, op.getBufferA (), b.getArrayAttr ({}),
482- ValueRange{tid, i}, /* forceUnroll=*/ true , /* useIndexDiffs=*/ true );
489+ Value bufferA = adaptor.getBufferA ();
490+ if (loadAFromLDS) {
491+ // regsA = read A from LDS
492+ b.create <ThreadwiseReadIntoOp>(
493+ loc, wrappedLDSBufferForLoadA, bufferA, b.getArrayAttr ({}),
494+ ValueRange{tid, i}, /* forceUnroll=*/ true , /* useIndexDiffs=*/ true );
495+ } else {
496+ if (cast<ShapedType>(bufferA.getType ()).getRank () == 1 ) {
497+ BottomUpTMBuilder regsBuilder (b, {" mk" }, {mRepeats * kBasePerThread },
498+ loc);
499+ regsBuilder.unmerge ({" iidx" , " k" }, {0 , 1 }, " mk" ,
500+ {mRepeats , kBasePerThread });
501+ bufferA =
502+ rock::transform (b, bufferA, b.getArrayAttr ({regsBuilder.get ()}));
503+ }
504+ bufferA = rock::createSliceOfFirstDim (b, loc, bufferA, i);
505+ }
506+ Value viewA =
507+ accelEmitterPtr->generateThreadwiseViewBufferA (b, loc, bufferA);
483508
484509 auto nLoop = b.create <affine::AffineForOp>(loc, 0 , nRepeats);
485510 {
486511 OpBuilder::InsertionGuard guard (b);
487512 b.setInsertionPointToStart (nLoop.getBody ());
488513 Value j = nLoop.getInductionVar ();
489514
490- // regsB = read B from LDS
491- b.create <ThreadwiseReadIntoOp>(
492- loc, wrappedLDSBufferForLoadB, op.getBufferB (), b.getArrayAttr ({}),
493- ValueRange{tid, j}, /* forceUnroll=*/ true , /* useIndexDiffs=*/ true );
515+ Value bufferB = adaptor.getBufferB ();
516+ if (loadBFromLDS) {
517+ // regsB = read B from LDS
518+ b.create <ThreadwiseReadIntoOp>(
519+ loc, wrappedLDSBufferForLoadB, bufferB, b.getArrayAttr ({}),
520+ ValueRange{tid, j}, /* forceUnroll=*/ true , /* useIndexDiffs=*/ true );
521+ } else {
522+ if (cast<ShapedType>(bufferB.getType ()).getRank () == 1 ) {
523+ BottomUpTMBuilder regsBBuilder (b, {" nk" },
524+ {nRepeats * kBasePerThread }, loc);
525+ regsBBuilder.unmerge ({" jidx" , " k" }, {0 , 1 }, " nk" ,
526+ {nRepeats, kBasePerThread });
527+ bufferB = rock::transform (b, bufferB,
528+ b.getArrayAttr ({regsBBuilder.get ()}));
529+ }
530+ bufferB = rock::createSliceOfFirstDim (b, loc, bufferB, j);
531+ }
532+ Value viewB =
533+ accelEmitterPtr->generateThreadwiseViewBufferB (b, loc, bufferB);
494534
495535 // regsC += regsA * regsB
496536 auto kLoop = b.create <affine::AffineForOp>(loc, 0 , kBasePerThread );
497537 {
498538 OpBuilder::InsertionGuard guard (b);
499539 b.setInsertionPointToStart (kLoop .getBody ());
500- Value viewA = accelEmitterPtr->generateThreadwiseViewBufferA (
501- b, loc, adaptor.getBufferA ());
502- Value viewB = accelEmitterPtr->generateThreadwiseViewBufferB (
503- b, loc, adaptor.getBufferB ());
504540 Value viewC = accelEmitterPtr->generateThreadwiseViewBufferC (
505541 b, loc, adaptor.getMatrixC ());
506542 Value k = kLoop .getInductionVar ();
0 commit comments