@@ -89,9 +89,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
8989 if (elementType.isF32 ())
9090 return f32 ;
9191 if (elementType.getIntOrFloatBitWidth () < 32 )
92- return rewriter. create < arith::TruncFOp>( loc, desType, f32 );
92+ return arith::TruncFOp::create (rewriter, loc, desType, f32 );
9393 if (elementType.getIntOrFloatBitWidth () > 32 )
94- return rewriter. create < arith::ExtFOp>( loc, desType, f32 );
94+ return arith::ExtFOp::create (rewriter, loc, desType, f32 );
9595 llvm_unreachable (" The only 32-bit float type is f32" );
9696}
9797
@@ -113,26 +113,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
113113 Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
114114 VectorType extResType = VectorType::get (2 , rewriter.getF32Type ());
115115 if (!inVecType) {
116- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
116+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
117117 loc, rewriter.getF32Type (), in, 0 );
118118 Value result = castF32To (outElemType, asFloat, loc, rewriter);
119119 rewriter.replaceOp (op, result);
120120 return success ();
121121 }
122122 int64_t numElements = inVecType.getNumElements ();
123123
124- Value zero = rewriter. create < arith::ConstantOp>(
124+ Value zero = arith::ConstantOp::create (rewriter,
125125 loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
126126 VectorType outType = cast<VectorType>(op.getOut ().getType ());
127127
128128 if (inVecType.getShape ().empty ()) {
129129 Value zerodSplat =
130130 rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
131131 Value scalarIn =
132- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
132+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
133133 Value scalarExt =
134- rewriter. create < arith::ExtFOp>( loc, outElemType, scalarIn);
135- Value result = rewriter. create < vector::InsertOp>( loc, scalarExt, zerodSplat,
134+ arith::ExtFOp::create (rewriter, loc, outElemType, scalarIn);
135+ Value result = vector::InsertOp::create (rewriter, loc, scalarExt, zerodSplat,
136136 ArrayRef<int64_t >{});
137137 rewriter.replaceOp (op, result);
138138 return success ();
@@ -145,32 +145,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
145145 if (inVecType.getRank () > 1 ) {
146146 inVecType = VectorType::get (SmallVector<int64_t >{numElements},
147147 inVecType.getElementType ());
148- in = rewriter. create < vector::ShapeCastOp>( loc, inVecType, in);
148+ in = vector::ShapeCastOp::create (rewriter, loc, inVecType, in);
149149 }
150150
151151 for (int64_t i = 0 ; i < numElements; i += 4 ) {
152152 int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
153- Value inSlice = rewriter. create < vector::ExtractStridedSliceOp>(
153+ Value inSlice = vector::ExtractStridedSliceOp::create (rewriter,
154154 loc, in, i, elemsThisOp, 1 );
155155 for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
156156 if (i + j + 1 < numElements) { // Convert two 8-bit elements
157- Value asFloats = rewriter. create < amdgpu::ExtPackedFp8Op>(
157+ Value asFloats = amdgpu::ExtPackedFp8Op::create (rewriter,
158158 loc, extResType, inSlice, j / 2 );
159159 Type desType = VectorType::get (2 , outElemType);
160160 Value asType = castF32To (desType, asFloats, loc, rewriter);
161- result = rewriter. create < vector::InsertStridedSliceOp>(
161+ result = vector::InsertStridedSliceOp::create (rewriter,
162162 loc, asType, result, i + j, 1 );
163163 } else { // Convert a 8-bit element
164- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
164+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
165165 loc, rewriter.getF32Type (), inSlice, j / 2 * 2 );
166166 Value asType = castF32To (outElemType, asFloat, loc, rewriter);
167- result = rewriter. create < vector::InsertOp>( loc, asType, result, i + j);
167+ result = vector::InsertOp::create (rewriter, loc, asType, result, i + j);
168168 }
169169 }
170170 }
171171
172172 if (inVecType.getRank () != outType.getRank ()) {
173- result = rewriter. create < vector::ShapeCastOp>( loc, outType, result);
173+ result = vector::ShapeCastOp::create (rewriter, loc, outType, result);
174174 }
175175
176176 rewriter.replaceOp (op, result);
@@ -182,9 +182,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
182182 if (type.isF32 ())
183183 return value;
184184 if (type.getIntOrFloatBitWidth () < 32 )
185- return rewriter. create < arith::ExtFOp>( loc, rewriter.getF32Type (), value);
185+ return arith::ExtFOp::create (rewriter, loc, rewriter.getF32Type (), value);
186186 if (type.getIntOrFloatBitWidth () > 32 )
187- return rewriter. create < arith::TruncFOp>( loc, rewriter.getF32Type (), value);
187+ return arith::TruncFOp::create (rewriter, loc, rewriter.getF32Type (), value);
188188 llvm_unreachable (" The only 32-bit float type is f32" );
189189}
190190
@@ -224,13 +224,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
224224 loc, arith::CmpFPredicate::OEQ, source, negInf);
225225 Value isNan = rewriter.createOrFold <arith::CmpFOp>(
226226 loc, arith::CmpFPredicate::UNO, source, source);
227- Value isNonFinite = rewriter. create < arith::OrIOp>(
228- loc, rewriter. create < arith::OrIOp>( loc, isInf, isNegInf), isNan);
227+ Value isNonFinite = arith::OrIOp::create (rewriter,
228+ loc, arith::OrIOp::create (rewriter, loc, isInf, isNegInf), isNan);
229229
230- Value clampedBelow = rewriter. create < arith::MaximumFOp>( loc, source, minCst);
231- Value clamped = rewriter. create < arith::MinimumFOp>( loc, clampedBelow, maxCst);
230+ Value clampedBelow = arith::MaximumFOp::create (rewriter, loc, source, minCst);
231+ Value clamped = arith::MinimumFOp::create (rewriter, loc, clampedBelow, maxCst);
232232 Value res =
233- rewriter. create < arith::SelectOp>( loc, isNonFinite, source, clamped);
233+ arith::SelectOp::create (rewriter, loc, isNonFinite, source, clamped);
234234 return res;
235235}
236236
@@ -264,24 +264,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
264264 VectorType truncResType = VectorType::get (4 , outElemType);
265265 if (!inVectorTy) {
266266 Value asFloat = castToF32 (in, loc, rewriter);
267- Value asF8s = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
267+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
268268 loc, truncResType, asFloat, /* sourceB=*/ nullptr , 0 ,
269269 /* existing=*/ nullptr );
270- Value result = rewriter. create < vector::ExtractOp>( loc, asF8s, 0 );
270+ Value result = vector::ExtractOp::create (rewriter, loc, asF8s, 0 );
271271 rewriter.replaceOp (op, result);
272272 return success ();
273273 }
274274
275275 int64_t numElements = outVecType.getNumElements ();
276- Value zero = rewriter. create < arith::ConstantOp>(
276+ Value zero = arith::ConstantOp::create (rewriter,
277277 loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
278278 if (outVecType.getShape ().empty ()) {
279279 Value scalarIn =
280- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
280+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
281281 // Recurse to send the 0-D vector case to the 1-D vector case
282282 Value scalarTrunc =
283- rewriter. create < arith::TruncFOp>( loc, outElemType, scalarIn);
284- Value result = rewriter. create < vector::InsertOp>( loc, scalarTrunc, zero,
283+ arith::TruncFOp::create (rewriter, loc, outElemType, scalarIn);
284+ Value result = vector::InsertOp::create (rewriter, loc, scalarTrunc, zero,
285285 ArrayRef<int64_t >{});
286286 rewriter.replaceOp (op, result);
287287 return success ();
@@ -294,32 +294,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
294294 if (inVectorTy.getRank () > 1 ) {
295295 inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
296296 inVectorTy.getElementType ());
297- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
297+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
298298 }
299299
300300 for (int64_t i = 0 ; i < numElements; i += 4 ) {
301301 int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
302302 Value thisResult = nullptr ;
303303 for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
304- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i + j);
304+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i + j);
305305 Value asFloatA = castToF32 (elemA, loc, rewriter);
306306 Value asFloatB = nullptr ;
307307 if (j + 1 < elemsThisOp) {
308- Value elemB = rewriter. create < vector::ExtractOp>( loc, in, i + j + 1 );
308+ Value elemB = vector::ExtractOp::create (rewriter, loc, in, i + j + 1 );
309309 asFloatB = castToF32 (elemB, loc, rewriter);
310310 }
311- thisResult = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
311+ thisResult = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
312312 loc, truncResType, asFloatA, asFloatB, j / 2 , thisResult);
313313 }
314314 if (elemsThisOp < 4 )
315- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
315+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
316316 loc, thisResult, 0 , elemsThisOp, 1 );
317- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
317+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
318318 result, i, 1 );
319319 }
320320
321321 if (inVectorTy.getRank () != outVecType.getRank ()) {
322- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
322+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
323323 }
324324
325325 rewriter.replaceOp (op, result);
@@ -347,10 +347,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
347347
348348 // Handle the case where input type is not a vector type
349349 if (!inVectorTy) {
350- auto sourceB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
350+ auto sourceB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
351351 Value asF16s =
352- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, in, sourceB);
353- Value result = rewriter. create < vector::ExtractOp>( loc, asF16s, 0 );
352+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, in, sourceB);
353+ Value result = vector::ExtractOp::create (rewriter, loc, asF16s, 0 );
354354 rewriter.replaceOp (op, result);
355355 return success ();
356356 }
@@ -362,33 +362,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
362362 if (inVectorTy.getRank () > 1 ) {
363363 inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
364364 inVectorTy.getElementType ());
365- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
365+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
366366 }
367367
368368 // Handle the vector case. We also handle the (uncommon) case where the vector
369369 // length is odd
370370 for (int64_t i = 0 ; i < numElements; i += 2 ) {
371371 int64_t elemsThisOp = std::min (numElements, i + 2 ) - i;
372372 Value thisResult = nullptr ;
373- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i);
374- Value elemB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
373+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i);
374+ Value elemB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
375375
376376 if (elemsThisOp == 2 ) {
377- elemB = rewriter. create < vector::ExtractOp>( loc, in, i + 1 );
377+ elemB = vector::ExtractOp::create (rewriter, loc, in, i + 1 );
378378 }
379379
380380 thisResult =
381- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, elemA, elemB);
381+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, elemA, elemB);
382382 // Place back the truncated result into the possibly larger vector. If we
383383 // are operating on a size 2 vector, these operations should be folded away
384- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
384+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
385385 loc, thisResult, 0 , elemsThisOp, 1 );
386- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
386+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
387387 result, i, 1 );
388388 }
389389
390390 if (inVectorTy.getRank () != outVecType.getRank ()) {
391- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
391+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
392392 }
393393
394394 rewriter.replaceOp (op, result);
0 commit comments