@@ -367,12 +367,13 @@ Value getSmemVecAddrNEW(const LinearLayout ®Layout,
367367 // solution for all swizzled shared memory scenarios, including the edge case
368368 // mentioned above.
369369 if (isSimpleSharedMemoryAccess (shape, allocShape, sharedEnc)) { // Case 1
370- smemOffset = applyLinearLayout (loc, rewriter, regToSharedLayout,
370+ auto res = applyLinearLayout (loc, rewriter, regToSharedLayout,
371371 {{kRegister , regId},
372372 {kLane , laneId},
373373 {kWarp , warpId},
374- {kBlock , blockId}})[0 ]
375- .second ;
374+ {kBlock , blockId}});
375+ std::cout << " linearLayRes.size(): " << res.size () << " \n " ;
376+ smemOffset = res[0 ].second ;
376377 } else { // Case 2 -> rank-reduced swizzling
377378 assert (rank >= 2 && " Swizzling only applies to tensors with rank >= 2" );
378379 assert (!sharedEnc.getHasLeadingOffset () &&
@@ -426,7 +427,7 @@ Value getSmemVecAddrNEW(const LinearLayout ®Layout,
426427} // namespace
427428
428429
429- bool getBoolFromEnv (const std::string& envVar, bool defaultValue = false ) {
430+ static bool getBoolFromEnv (const std::string& envVar, bool defaultValue = false ) {
430431 const char * value = std::getenv (envVar.c_str ());
431432 if (value == nullptr ) {
432433 return defaultValue; // Return default if the variable is not set
@@ -549,10 +550,18 @@ bool emitTransferBetweenRegistersAndSharedNEW(
549550 StringAttr kWarp = str_attr (" warp" );
550551
551552 auto shape = sharedTy.getShape ();
553+ llvm::dbgs () << " registerTy enc\n " ;
554+ registerTy.dump ();
555+ registerTy.getEncoding ().dump ();
556+ llvm::dbgs () << " shape: " ; for (auto &el : shape) { llvm::dbgs () << el << " " ;} llvm::dbgs () << " \n " ;
552557 LinearLayout regLayout =
553558 triton::gpu::toLinearLayout (shape, registerTy.getEncoding ());
554559 printLinearThing (regLayout, " regLayout" );
555560
561+ llvm::dbgs () << " sharedTy enc\n " ;
562+ sharedTy.dump ();
563+ sharedTy.getEncoding ().dump ();
564+ llvm::dbgs () << " shape: " ; for (auto &el : shape) { llvm::dbgs () << el << " " ;} llvm::dbgs () << " \n " ;
556565 LinearLayout sharedLayout = triton::gpu::toLinearLayout (
557566 shape, sharedTy.getEncoding (), elemLlvmTy.getIntOrFloatBitWidth ());
558567 printLinearThing (sharedLayout, " sharedLayout" );
@@ -653,13 +662,30 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
653662 bool success = emitTransferBetweenRegistersAndShared (
654663 dstTy, srcTy, elemLlvmTy, /* maxVecElems=*/ std::nullopt , smemObj, loc,
655664 rewriter, target, [&](VectorType vecTy, Value vecAddr) {
656- auto vecVal = load (vecTy, vecAddr);
657- vecVal.setAlignment (vecTy.getNumElements () *
658- elemLlvmTy.getIntOrFloatBitWidth () / 8 );
659-
660- for (int v = 0 ; v < vecTy.getNumElements (); v++) {
661- ret.push_back (extract_element (elemLlvmTy, vecVal, i32_val (v)));
665+ if (vecTy.getNumElements () >= 64 ) {
666+ assert (vecTy.getNumElements () % 64 == 0 );
667+ for (int i = 0 ; i < vecTy.getNumElements (); i+=64 ) {
668+ auto smallVecTy = vec_ty (elemLlvmTy, 64 );
669+ auto vecAddrNew = gep (vecAddr.getType (), i32_ty, vecAddr, SmallVector<Value>({i32_val (i)}));
670+ auto vecVal = load (smallVecTy, vecAddrNew);
671+ vecVal.setAlignment (smallVecTy.getNumElements () *
672+ elemLlvmTy.getIntOrFloatBitWidth () / 8 );
673+
674+ for (int v = 0 ; v < 64 ; v++) {
675+ ret.push_back (extract_element (elemLlvmTy, vecVal, i32_val (v)));
676+ }
677+ }
678+
679+ } else {
680+ auto vecVal = load (vecTy, vecAddr);
681+ vecVal.setAlignment (vecTy.getNumElements () *
682+ elemLlvmTy.getIntOrFloatBitWidth () / 8 );
683+
684+ for (int v = 0 ; v < vecTy.getNumElements (); v++) {
685+ ret.push_back (extract_element (elemLlvmTy, vecVal, i32_val (v)));
686+ }
662687 }
688+
663689 });
664690 if (!success)
665691 llvm::report_fatal_error (" Failed to emit transfer from shared to register" );
0 commit comments