@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
5050 assert (targetBits % sourceBits == 0 );
5151 Type type = srcIdx.getType ();
5252 IntegerAttr idxAttr = builder.getIntegerAttr (type, targetBits / sourceBits);
53- auto idx = builder.create <spirv::ConstantOp>(loc, type, idxAttr);
53+ auto idx = builder.createOrFold <spirv::ConstantOp>(loc, type, idxAttr);
5454 IntegerAttr srcBitsAttr = builder.getIntegerAttr (type, sourceBits);
55- auto srcBitsValue = builder.create <spirv::ConstantOp>(loc, type, srcBitsAttr);
56- auto m = builder.create <spirv::UModOp>(loc, srcIdx, idx);
57- return builder.create <spirv::IMulOp>(loc, type, m, srcBitsValue);
55+ auto srcBitsValue =
56+ builder.createOrFold <spirv::ConstantOp>(loc, type, srcBitsAttr);
57+ auto m = builder.createOrFold <spirv::UModOp>(loc, srcIdx, idx);
58+ return builder.createOrFold <spirv::IMulOp>(loc, type, m, srcBitsValue);
5859}
5960
6061// / Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
7475 Value lastDim = op->getOperand (op.getNumOperands () - 1 );
7576 Type type = lastDim.getType ();
7677 IntegerAttr attr = builder.getIntegerAttr (type, targetBits / sourceBits);
77- auto idx = builder.create <spirv::ConstantOp>(loc, type, attr);
78+ auto idx = builder.createOrFold <spirv::ConstantOp>(loc, type, attr);
7879 auto indices = llvm::to_vector<4 >(op.getIndices ());
7980 // There are two elements if this is a 1-D tensor.
8081 assert (indices.size () == 2 );
81- indices.back () = builder.create <spirv::SDivOp>(loc, lastDim, idx);
82+ indices.back () = builder.createOrFold <spirv::SDivOp>(loc, lastDim, idx);
8283 Type t = typeConverter.convertType (op.getComponentPtr ().getType ());
8384 return builder.create <spirv::AccessChainOp>(loc, t, op.getBasePtr (), indices);
8485}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
9192 return srcBool;
9293 Value zero = spirv::ConstantOp::getZero (dstType, loc, builder);
9394 Value one = spirv::ConstantOp::getOne (dstType, loc, builder);
94- return builder.create <spirv::SelectOp>(loc, dstType, srcBool, one, zero);
95+ return builder.createOrFold <spirv::SelectOp>(loc, dstType, srcBool, one,
96+ zero);
9597}
9698
9799// / Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
111113 loc, builder.getIntegerType (targetBits), value);
112114 }
113115
114- value = builder.create <spirv::BitwiseAndOp>(loc, value, mask);
116+ value = builder.createOrFold <spirv::BitwiseAndOp>(loc, value, mask);
115117 }
116- return builder.create <spirv::ShiftLeftLogicalOp>(loc, value.getType (), value ,
117- offset);
118+ return builder.createOrFold <spirv::ShiftLeftLogicalOp>(loc, value.getType (),
119+ value, offset);
118120}
119121
120122// / Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
165167 return srcInt;
166168
167169 auto one = spirv::ConstantOp::getOne (srcInt.getType (), loc, builder);
168- return builder.create <spirv::IEqualOp>(loc, srcInt, one);
170+ return builder.createOrFold <spirv::IEqualOp>(loc, srcInt, one);
169171}
170172
171173// ===----------------------------------------------------------------------===//
@@ -597,25 +599,26 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
597599 // ____XXXX________ -> ____________XXXX
598600 Value lastDim = accessChainOp->getOperand (accessChainOp.getNumOperands () - 1 );
599601 Value offset = getOffsetForBitwidth (loc, lastDim, srcBits, dstBits, rewriter);
600- Value result = rewriter.create <spirv::ShiftRightArithmeticOp>(
602+ Value result = rewriter.createOrFold <spirv::ShiftRightArithmeticOp>(
601603 loc, spvLoadOp.getType (), spvLoadOp, offset);
602604
603605 // Apply the mask to extract corresponding bits.
604- Value mask = rewriter.create <spirv::ConstantOp>(
606+ Value mask = rewriter.createOrFold <spirv::ConstantOp>(
605607 loc, dstType, rewriter.getIntegerAttr (dstType, (1 << srcBits) - 1 ));
606- result = rewriter.create <spirv::BitwiseAndOp>(loc, dstType, result, mask);
608+ result =
609+ rewriter.createOrFold <spirv::BitwiseAndOp>(loc, dstType, result, mask);
607610
608611 // Apply sign extension on the loading value unconditionally. The signedness
609612 // semantic is carried in the operator itself, we relies other pattern to
610613 // handle the casting.
611614 IntegerAttr shiftValueAttr =
612615 rewriter.getIntegerAttr (dstType, dstBits - srcBits);
613616 Value shiftValue =
614- rewriter.create <spirv::ConstantOp>(loc, dstType, shiftValueAttr);
615- result = rewriter.create <spirv::ShiftLeftLogicalOp>(loc, dstType, result ,
616- shiftValue);
617- result = rewriter.create <spirv::ShiftRightArithmeticOp>(loc, dstType, result,
618- shiftValue);
617+ rewriter.createOrFold <spirv::ConstantOp>(loc, dstType, shiftValueAttr);
618+ result = rewriter.createOrFold <spirv::ShiftLeftLogicalOp>(loc, dstType,
619+ result, shiftValue);
620+ result = rewriter.createOrFold <spirv::ShiftRightArithmeticOp>(
621+ loc, dstType, result, shiftValue);
619622
620623 rewriter.replaceOp (loadOp, result);
621624
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
744747
745748 // Create a mask to clear the destination. E.g., if it is the second i8 in
746749 // i32, 0xFFFF00FF is created.
747- Value mask = rewriter.create <spirv::ConstantOp>(
750+ Value mask = rewriter.createOrFold <spirv::ConstantOp>(
748751 loc, dstType, rewriter.getIntegerAttr (dstType, (1 << srcBits) - 1 ));
749- Value clearBitsMask =
750- rewriter.create <spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
751- clearBitsMask = rewriter.create <spirv::NotOp>(loc, dstType, clearBitsMask);
752+ Value clearBitsMask = rewriter.createOrFold <spirv::ShiftLeftLogicalOp>(
753+ loc, dstType, mask, offset);
754+ clearBitsMask =
755+ rewriter.createOrFold <spirv::NotOp>(loc, dstType, clearBitsMask);
752756
753757 Value storeVal = shiftValue (loc, adaptor.getValue (), offset, mask, rewriter);
754758 Value adjustedPtr = adjustAccessChainForBitwidth (typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
910914
911915 int64_t attrVal = cast<IntegerAttr>(offset.get <Attribute>()).getInt ();
912916 Attribute attr = rewriter.getIntegerAttr (intType, attrVal);
913- return rewriter.create <spirv::ConstantOp>(loc, intType, attr);
917+ return rewriter.createOrFold <spirv::ConstantOp>(loc, intType, attr);
914918 }();
915919
916920 rewriter.replaceOpWithNewOp <spirv::InBoundsPtrAccessChainOp>(
0 commit comments