@@ -430,34 +430,43 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     existingCastOps.push_back(castOp.getOperation());
   });
 
-  TypeConverter converter;
-  converter.addConversion([&](Type type) -> Type { return type; });
-  converter.addConversion(
-      [&](RankedTensorType type,
-          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-        Type elemTy = type.getElementType();
-        ArrayRef<int64_t> shape = type.getShape();
-
-        int count;
-        SmallVector<int64_t> subShape;
-        std::tie(subShape, count) = getSgShapeAndCount(
-            shape, dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
-
-        auto newTy = VectorType::get(subShape, elemTy);
-        result.append(count, newTy);
-        return success();
-      });
-
-  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
-  // VectorType operands. This first converts such operands to RankedTensorType,
-  // propagates the layout attribute into the encoding attribute, and finally
-  // converts the RankedTensorType to VectorType based on the encoding.
-  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+  {
+    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+    // VectorType operands. This first converts such operands to
+    // RankedTensorType, propagates the layout attribute into the encoding
+    // attribute, and finally converts the RankedTensorType to VectorType based
+    // on the encoding.
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          Type elemTy = type.getElementType();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          int count;
+          SmallVector<int64_t> subShape;
+          std::tie(subShape, count) = getSgShapeAndCount(
+              shape,
+              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
+
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                       converter);
+  }
 
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
-
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
@@ -516,8 +525,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
-  // as well as XeGPU, Arith, and Vector operations.
   scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                        target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
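
For orientation: the one-to-N conversion in Step 1 maps a single workgroup-level tensor type to `count` subgroup-sized vector types, with both values coming from `getSgShapeAndCount`. The sketch below illustrates the kind of arithmetic such a helper performs, assuming a layout that carries an `sg_layout` (the subgroup grid) and `sg_data` (the per-subgroup tile shape) and assuming even divisibility; the function name, signature, and rounding rules here are illustrative stand-ins, not the pass's actual implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using std::int64_t;

// Illustrative stand-in for getSgShapeAndCount: given a workgroup-level
// shape and a subgroup layout, return the tile shape each subgroup owns
// and how many such tiles one subgroup handles in round-robin fashion.
// The real helper reads these fields from xegpu::LayoutAttr.
static std::pair<std::vector<int64_t>, int64_t>
getSgShapeAndCountSketch(const std::vector<int64_t> &shape,
                         const std::vector<int64_t> &sgLayout,
                         const std::vector<int64_t> &sgData) {
  assert(shape.size() == sgLayout.size() && shape.size() == sgData.size());
  std::vector<int64_t> sgShape = sgData; // per-subgroup tile shape
  int64_t count = 1;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    // One distribution round covers sgLayout[i] * sgData[i] elements along
    // dim i; a subgroup repeats shape[i] / round such rounds (assumed exact).
    int64_t round = sgLayout[i] * sgData[i];
    assert(shape[i] % round == 0 && "sketch assumes even divisibility");
    count *= shape[i] / round;
  }
  return {sgShape, count};
}

int main() {
  // A 256x128 workgroup tile on an 8x4 subgroup grid with 16x16 tiles:
  // each subgroup owns 2x2 = 4 tiles.
  auto [sgShape, count] =
      getSgShapeAndCountSketch({256, 128}, {8, 4}, {16, 16});
  std::cout << sgShape[0] << "x" << sgShape[1] << " count=" << count << "\n";
  return 0;
}
```

Under these assumptions, the `result.append(count, newTy)` call in the patch would append four subgroup-sized vector types (e.g., four `vector<16x16xf32>` for one `tensor<256x128xf32>` value), which is the one-to-N result the SCF structural conversion is set up to carry through loop regions.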