@@ -329,13 +329,13 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
329329
330330// Handles UnrealizedConversionCastOp generated during
331331// SCFStructuralTypeConversions (step 1). This op may appear as either a
332- // target or source materialization for Vector or TensorDesc, e.g.:
333- // 1. unrealized_conversion_cast %1 : tensor_desc<16xf16 > to
334- // tensor_desc<128xf16 , ...>
335- // 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ...
336- // 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32>
337- // In all cases, the pattern simply forwards the inputs to the outputs with
338- // one-to-one or one-to-n patterns .
332+ // target or source materialization for Vector or TensorDesc values, e.g.:
333+ // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
334+ // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
335+ // it could be either 1:1, 1:N or N:1 cast. In all cases, the pattern
336+ // simply forwards the inputs to the outputs using 1:1 or 1:N interface.
337+ // TODO: remove it when context-aware type converter is ready.
338+ // It is safe only when the input code doesn't contain UnrealizedConversionCastOp.
339339struct UnrealizedConversionCastOpPattern
340340 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
341341 using OpConversionPattern<
@@ -346,34 +346,28 @@ struct UnrealizedConversionCastOpPattern
346346 ConversionPatternRewriter &rewriter) const override {
347347 SmallVector<Value> inputs = xegpu::flattenValues (adaptor.getInputs ());
348348
349- // Handles the case where cast %1 : tensor_desc<16xf16> to
350- // tensor_desc<128xf16, ...> The input values provided by the adaptor should
351- // already be distributed.
352- if (op.getNumOperands () == 1 && op.getNumResults () == 1 &&
353- isa<xegpu::TensorDescType>(op->getOperand (0 ).getType ()) &&
354- isa<xegpu::TensorDescType>(op->getResult (0 ).getType ())) {
355- rewriter.replaceOp (op, inputs);
356- return success ();
357- }
349+ auto inputTy = inputs[0 ].getType ();
350+ auto outputTy = op->getOpResult (0 ).getType ();
351+
352+ if (!llvm::all_equal (op->getResultTypes ()) ||
353+ !llvm::all_equal (ValueRange (inputs).getTypes ()) ||
354+ !isa<VectorType, xegpu::TensorDescType>(inputTy) ||
355+ !isa<VectorType, xegpu::TensorDescType>(outputTy))
356+ return failure ();
358357
359358 // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
360359 // the input values provided by the adaptor should already be distributed,
361360 // and their types should correspond exactly to the result types of the
362361 // operation.
363- if (op.getNumOperands () == 1 &&
364- llvm::equal (ValueRange (inputs).getTypes (), op->getResultTypes ())) {
362+ if (op.getNumOperands () == 1 ) {
365363 rewriter.replaceOp (op, inputs);
366364 return success ();
367365 }
368366
369367 // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
370368 // All input values must have the same vector type, and their shape must be
371369 // evenly divisible by the output vector's shape.
372- auto inputTy = dyn_cast<VectorType>(inputs[0 ].getType ());
373- auto outputTy = dyn_cast<VectorType>(op->getOpResult (0 ).getType ());
374- if (op.getNumResults () == 1 && inputTy && outputTy &&
375- llvm::all_equal (ValueRange (inputs).getTypes ()) &&
376- computeShapeRatio (outputTy.getShape (), inputTy.getShape ())) {
370+ if (op.getNumResults () == 1 ) {
377371 rewriter.replaceOpWithMultiple (op, {inputs});
378372 return success ();
379373 }
0 commit comments