Skip to content

Commit f39fe3c

Browse files
committed
address comments
1 parent 89cac4d commit f39fe3c

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -430,34 +430,43 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
430430
existingCastOps.push_back(castOp.getOperation());
431431
});
432432

433-
TypeConverter converter;
434-
converter.addConversion([&](Type type) -> Type { return type; });
435-
converter.addConversion(
436-
[&](RankedTensorType type,
437-
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
438-
Type elemTy = type.getElementType();
439-
ArrayRef<int64_t> shape = type.getShape();
440-
441-
int count;
442-
SmallVector<int64_t> subShape;
443-
std::tie(subShape, count) = getSgShapeAndCount(
444-
shape, dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
445-
446-
auto newTy = VectorType::get(subShape, elemTy);
447-
result.append(count, newTy);
448-
return success();
449-
});
450-
451-
// Step 1: Apply SCFStructuralTypeConversions to SCF operations with
452-
// VectorType operands. This first converts such operands to RankedTensorType,
453-
// propagates the layout attribute into the encoding attribute, and finally
454-
// converts the RankedTensorType to VectorType based on the encoding.
455-
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
433+
{
434+
// Step 1: Apply SCFStructuralTypeConversions to SCF operations with
435+
// VectorType operands. This first converts such operands to
436+
// RankedTensorType, propagates the layout attribute into the encoding
437+
// attribute, and finally converts the RankedTensorType to VectorType based
438+
// on the encoding.
439+
440+
TypeConverter converter;
441+
converter.addConversion([&](Type type) -> Type { return type; });
442+
converter.addConversion(
443+
[&](RankedTensorType type,
444+
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
445+
Type elemTy = type.getElementType();
446+
ArrayRef<int64_t> shape = type.getShape();
447+
448+
int count;
449+
SmallVector<int64_t> subShape;
450+
std::tie(subShape, count) = getSgShapeAndCount(
451+
shape,
452+
dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
453+
454+
auto newTy = VectorType::get(subShape, elemTy);
455+
result.append(count, newTy);
456+
return success();
457+
});
458+
459+
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
460+
converter);
461+
}
456462

463+
// Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
464+
// as well as XeGPU, Arith, and Vector operations.
457465
MLIRContext *ctx = &getContext();
458466
RewritePatternSet patterns(ctx);
459467
ConversionTarget target(*ctx);
460-
468+
TypeConverter converter;
469+
converter.addConversion([&](Type type) -> Type { return type; });
461470
converter.addConversion(
462471
[&](xegpu::TensorDescType type,
463472
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
@@ -516,8 +525,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
516525

517526
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
518527

519-
// Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
520-
// as well as XeGPU, Arith, and Vector operations.
521528
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
522529
target);
523530
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);

0 commit comments

Comments
 (0)