diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e3563d10bc6f1..be7b860dd1729 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -376,10 +376,12 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
-          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
-        else
+        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
+          if (auto newLayout = layout.dropSgLayoutAndData())
+            state.addAttribute(attr.getName(), newLayout);
+        } else {
           state.addAttribute(attr.getName(), attr.getValue());
+        }
       }
       Operation *newOp = rewriter.create(state);
       newResults.push_back(newOp->getResult(0));
@@ -629,8 +631,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
       std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
         op->removeAttr(name);
-        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
-          op->setAttr(name, layout.dropSgLayoutAndData());
+        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
+          if (auto newLayout = layout.dropSgLayoutAndData())
+            op->setAttr(name, newLayout);
+        }
       }
     }
   });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 64f01d61d6e80..09df1e4da43e2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -1,6 +1,25 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
 gpu.module @test_elementwise_ops {
+
+  // CHECK-LABEL: unary_ops_sg_layout_only
+  gpu.func @unary_ops_sg_layout_only(%a: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>>
+      -> vector<24x32xf32>
+    // CHECK: math.exp {{.*}} : vector<12x8xf32>
+    %exp = math.exp %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>}
+      : vector<24x32xf32>
+    // CHECK: arith.negf {{.*}} : vector<12x8xf32>
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: unary_ops
   gpu.func @unary_ops(%a: memref<24x32xf32>) {
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>