@@ -62,6 +62,7 @@ struct ConvertLayoutOpConversion
6262 ArrayRef<unsigned > origRepShape,
6363 ArrayRef<unsigned > outOrd, SmallVector<Value> &vals,
6464 Value smemBase) const {
65+ auto b = TritonLLVMOpBuilder (loc, rewriter);
6566 auto accumNumCTAsEachRep = product<unsigned >(numCTAsEachRep);
6667 auto layout = type.getEncoding ();
6768 auto rank = type.getRank ();
@@ -110,29 +111,29 @@ struct ConvertLayoutOpConversion
110111 Value offset = LLVM::linearize (rewriter, loc, multiDimOffsetWrapped,
111112 paddedRepShape, outOrd);
112113 auto elemPtrTy = smemBase.getType ();
113- Value ptr = gep (elemPtrTy, llvmElemTy, smemBase, offset);
114+ Value ptr = b. gep (elemPtrTy, llvmElemTy, smemBase, offset);
114115 auto vecTy = vec_ty (llvmElemTy, vec);
115116 if (stNotRd) {
116- Value valVec = undef (vecTy);
117+ Value valVec = b. undef (vecTy);
117118 for (unsigned v = 0 ; v < vec; ++v) {
118119 auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
119120 if (isInt1)
120- currVal = zext (llvmElemTy, currVal);
121+ currVal = b. zext (llvmElemTy, currVal);
121122 else if (isPtr)
122- currVal = ptrtoint (llvmElemTy, currVal);
123- valVec = insert_element (vecTy, valVec, currVal, i32_val (v));
123+ currVal = b. ptrtoint (llvmElemTy, currVal);
124+ valVec = b. insert_element (vecTy, valVec, currVal, b. i32_val (v));
124125 }
125- store (valVec, ptr);
126+ b. store (valVec, ptr);
126127 } else {
127- Value valVec = load (vecTy, ptr);
128+ Value valVec = b. load (vecTy, ptr);
128129 for (unsigned v = 0 ; v < vec; ++v) {
129- Value currVal = extract_element (llvmElemTy, valVec, i32_val (v));
130+ Value currVal = b. extract_element (llvmElemTy, valVec, b. i32_val (v));
130131 if (isInt1)
131- currVal = icmp_ne (currVal,
132- rewriter.create <LLVM::ConstantOp>(
133- loc, i8_ty, rewriter.getI8IntegerAttr (0 )));
132+ currVal = b. icmp_ne (
133+ currVal, rewriter.create <LLVM::ConstantOp>(
134+ loc, i8_ty, rewriter.getI8IntegerAttr (0 )));
134135 else if (isPtr)
135- currVal = inttoptr (llvmElemTyOrig, currVal);
136+ currVal = b. inttoptr (llvmElemTyOrig, currVal);
136137 vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
137138 }
138139 }
@@ -146,6 +147,7 @@ struct ConvertLayoutOpConversion
146147 ConversionPatternRewriter &rewriter,
147148 const TargetInfoBase &targetInfo) const {
148149 auto loc = op.getLoc ();
150+ auto b = TritonLLVMOpBuilder (loc, rewriter);
149151 auto typeConverter = getTypeConverter ();
150152 RankedTensorType srcTy = op.getSrc ().getType ();
151153 RankedTensorType dstTy = op.getType ();
@@ -205,12 +207,12 @@ struct ConvertLayoutOpConversion
205207 auto multiDimRepId =
206208 getMultiDimIndex<unsigned >(repId, numReplicates, outOrd);
207209 if (repId != 0 ) {
208- barrier ();
210+ b. barrier ();
209211 }
210212 processReplica (loc, rewriter, /* stNotRd*/ true , srcTy, inNumCTAsEachRep,
211213 multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
212214 vals, smemBase);
213- barrier ();
215+ b. barrier ();
214216 processReplica (loc, rewriter, /* stNotRd*/ false , dstTy, outNumCTAsEachRep,
215217 multiDimRepId, outVec, paddedRepShape, origRepShape,
216218 outOrd, outVals, smemBase);
@@ -355,6 +357,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
355357 ConversionPatternRewriter &rewriter) const {
356358 MLIRContext *ctx = op.getContext ();
357359 auto loc = op.getLoc ();
360+ auto b = TritonLLVMOpBuilder (loc, rewriter);
358361 auto srcTy = op.getSrc ().getType ();
359362 auto dstTy = op.getType ();
360363
@@ -399,9 +402,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
399402 // Munge input values
400403 for (const auto &it : llvm::enumerate (inVals)) {
401404 if (isSubByteInt) {
402- inVals[it.index ()] = zext (llvmElemTy, it.value ());
405+ inVals[it.index ()] = b. zext (llvmElemTy, it.value ());
403406 } else if (isPtr) {
404- inVals[it.index ()] = ptrtoint (llvmElemTy, it.value ());
407+ inVals[it.index ()] = b. ptrtoint (llvmElemTy, it.value ());
405408 }
406409 }
407410
@@ -417,9 +420,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
417420 // Unmunge output values
418421 for (const auto &it : llvm::enumerate (outVals)) {
419422 if (isSubByteInt) {
420- outVals[it.index ()] = trunc (llvmElemTyOrig, it.value ());
423+ outVals[it.index ()] = b. trunc (llvmElemTyOrig, it.value ());
421424 } else if (isPtr) {
422- outVals[it.index ()] = inttoptr (llvmElemTyOrig, it.value ());
425+ outVals[it.index ()] = b. inttoptr (llvmElemTyOrig, it.value ());
423426 }
424427 }
425428
@@ -443,6 +446,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
443446 ConversionPatternRewriter &rewriter) const {
444447 MLIRContext *ctx = op.getContext ();
445448 auto loc = op.getLoc ();
449+ auto b = TritonLLVMOpBuilder (loc, rewriter);
446450
447451 StringAttr kRegister = str_attr (" register" );
448452 StringAttr kLane = str_attr (" lane" );
@@ -452,9 +456,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
452456 StringAttr kIteration = str_attr (" iteration" );
453457
454458 Value threadId = getThreadId (rewriter, loc);
455- Value threadsPerWarp = i32_val (srcLayout.getInDimSize (kLane ));
456- Value laneId = urem (threadId, threadsPerWarp);
457- Value warpId = udiv (threadId, threadsPerWarp);
459+ Value threadsPerWarp = b. i32_val (srcLayout.getInDimSize (kLane ));
460+ Value laneId = b. urem (threadId, threadsPerWarp);
461+ Value warpId = b. udiv (threadId, threadsPerWarp);
458462
459463 auto scratchConfig =
460464 getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
@@ -541,37 +545,38 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
541545 {kWarp , 0 },
542546 {kBlock , 0 }})[0 ]
543547 .second ;
544- Value offset = xor_ (regBase, i32_val (regIdx));
548+ Value offset = b. xor_ (regBase, b. i32_val (regIdx));
545549 if (paddedSize > 0 ) {
546550 assert (llvm::isPowerOf2_32 (paddedStride));
547551 assert (llvm::isPowerOf2_32 (paddedSize));
548552 auto rshiftVal = llvm::Log2_32 (paddedStride);
549553 auto lshiftVal = llvm::Log2_32 (paddedSize);
550- offset = add (shl (lshr (offset, i32_val (rshiftVal)), i32_val (lshiftVal)),
551- offset);
554+ offset = b.add (
555+ b.shl (b.lshr (offset, b.i32_val (rshiftVal)), b.i32_val (lshiftVal)),
556+ offset);
552557 }
553- auto vecAddr = gep (sharedPtrTy, elemTy, smemBase, offset);
558+ auto vecAddr = b. gep (sharedPtrTy, elemTy, smemBase, offset);
554559 vecAddr.setInbounds (true );
555560 return vecAddr;
556561 };
557562
558563 auto storeBase = applyLinearLayout (loc, rewriter, shmemStoreLayout,
559- {{kRegister , i32_val (0 )},
564+ {{kRegister , b. i32_val (0 )},
560565 {kLane , laneId},
561566 {kWarp , warpId},
562- {kBlock , i32_val (0 )}})[0 ]
567+ {kBlock , b. i32_val (0 )}})[0 ]
563568 .second ;
564569 auto loadBase = applyLinearLayout (loc, rewriter, shmemLoadLayout,
565- {{kRegister , i32_val (0 )},
570+ {{kRegister , b. i32_val (0 )},
566571 {kLane , laneId},
567572 {kWarp , warpId},
568- {kBlock , i32_val (0 )}})[0 ]
573+ {kBlock , b. i32_val (0 )}})[0 ]
569574 .second ;
570575 // register idx -> Value
571576 llvm::MapVector<int , Value> outVals;
572577 for (int i = 0 ; i < iterations; i++) {
573578 if (i != 0 )
574- barrier ();
579+ b. barrier ();
575580
576581 auto &inRegs = inRegsForIter[i];
577582 auto &outRegs = outRegsForIter[i];
@@ -591,19 +596,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
591596 targetInfo.storeMatrixShared (rewriter, loc, vecAddr, valsVec);
592597 } else {
593598 targetInfo.storeDShared (rewriter, loc, vecAddr, std::nullopt , valsVec,
594- /* pred=*/ true_val ());
599+ /* pred=*/ b. true_val ());
595600 }
596601 }
597602
598- barrier ();
603+ b. barrier ();
599604
600605 for (int j = 0 ; j < outSize / iterations; j += scratchConfig.outVec ) {
601606 auto outRegSlice = outRegs[j];
602607 auto vecAddr = getVecAddr (shmemLoadLayout, loadBase, outRegSlice);
603608 Value valsVec =
604609 targetInfo.loadDShared (rewriter, loc, vecAddr, std::nullopt ,
605610 vec_ty (elemTy, scratchConfig.outVec ),
606- /* pred=*/ true_val ());
611+ /* pred=*/ b. true_val ());
607612 for (Value v : unpackLLVector (loc, valsVec, rewriter))
608613 outVals[outRegSlice++] = v;
609614 }
@@ -646,6 +651,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
646651 ConversionPatternRewriter &rewriter) const {
647652 MLIRContext *ctx = op.getContext ();
648653 Location loc = op.getLoc ();
654+ auto b = TritonLLVMOpBuilder (loc, rewriter);
649655 StringAttr kRegister = str_attr (" register" );
650656 StringAttr kLane = str_attr (" lane" );
651657 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
@@ -657,8 +663,8 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
657663 SmallVector<Value> shflOuts (Cp.getInDimSize (kRegister ));
658664
659665 Value threadId = getThreadId (rewriter, loc);
660- Value threadsPerWarp = i32_val (Cp.getInDimSize (kLane ));
661- Value laneId = urem (threadId, threadsPerWarp);
666+ Value threadsPerWarp = b. i32_val (Cp.getInDimSize (kLane ));
667+ Value laneId = b. urem (threadId, threadsPerWarp);
662668
663669 // Emit one shuffle per destination register.
664670 for (int i : llvm::seq (shflOuts.size ())) {
@@ -667,22 +673,22 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
667673 // At the same time, for each register, P1 returns the source value index
668674 // to provide as the shuffle value.
669675 auto out = applyLinearLayout (loc, rewriter, P1,
670- {{kLane , laneId}, {kRegister , i32_val (i)}});
676+ {{kLane , laneId}, {kRegister , b. i32_val (i)}});
671677 assert (out.size () == 1 );
672678 Value srcRegIdx = out.front ().second ;
673679 // The size of the input lane dimension is the number of selects to emit.
674680 // TODO(jeff): For dtypes smaller than i32, we can use byte permutes and
675681 // shuffle multiple values at a time.
676- Value shflSrc = undef (srcValues.front ().getType ());
682+ Value shflSrc = b. undef (srcValues.front ().getType ());
677683 for (int j : llvm::seq (reducedP1.getInDimSize (kLane ))) {
678684 int32_t check =
679685 reducedP1.apply ({{kLane , j}, {kRegister , i}}).front ().second ;
680- shflSrc =
681- select ( icmp_eq (srcRegIdx, i32_val (check)), srcValues[check], shflSrc);
686+ shflSrc = b. select (b. icmp_eq (srcRegIdx, b. i32_val (check)),
687+ srcValues[check], shflSrc);
682688 }
683689
684690 out = applyLinearLayout (loc, rewriter, Cp,
685- {{kLane , laneId}, {kRegister , i32_val (i)}});
691+ {{kLane , laneId}, {kRegister , b. i32_val (i)}});
686692 assert (out.size () == 1 );
687693 Value shflIdx = out.front ().second ;
688694 shflOuts[i] = targetInfo.shuffleIdx (rewriter, loc, shflSrc, shflIdx);
@@ -693,16 +699,16 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
693699 // selects.
694700 SmallVector<Value> results (shflOuts.size ());
695701 for (int i : llvm::seq (results.size ())) {
696- Value result = undef (srcValues.front ().getType ());
702+ Value result = b. undef (srcValues.front ().getType ());
697703
698704 auto out = applyLinearLayout (loc, rewriter, P2inv,
699- {{kLane , laneId}, {kRegister , i32_val (i)}});
705+ {{kLane , laneId}, {kRegister , b. i32_val (i)}});
700706 Value resultIdx = out.front ().second ;
701707 for (int j : llvm::seq (reducedP2.getInDimSize (kLane ))) {
702708 int32_t check =
703709 reducedP2.apply ({{kLane , j}, {kRegister , i}}).front ().second ;
704- result =
705- select ( icmp_eq (resultIdx, i32_val (check)), shflOuts[check], result);
710+ result = b. select (b. icmp_eq (resultIdx, b. i32_val (check)), shflOuts[check],
711+ result);
706712 }
707713 results[i] = result;
708714 }
0 commit comments