Commit 96a664f

[XPU][OptEW] Allow use-def graphs of elementwise optimizable operations
Allow operands to be used by other optimizable operations so that whole use-def graphs of elementwise operations can be optimized.

Signed-off-by: victor-eds <[email protected]>
1 parent: 4923e1c
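In short: the pattern previously required every operand of the candidate elementwise operation to have exactly one use. This commit replaces that per-operand hasOneUse() test with a recursive walk of the use-def graph, accepting values with several users as long as every transitive user is either another optimizable elementwise operation or a layout conversion to the type the pass introduces anyway. A minimal sketch of the newly accepted shape, distilled from the test added below (#slice abbreviates the full slice-of-DPAS layout used there):

// %arg0 has two users, so the old hasOneUse() check rejected this graph.
// The new recursive check accepts it: every user is an optimizable addf.
%0 = arith.addf %arg0, %arg1 : tensor<16xf32, #slice>
%1 = arith.addf %arg0, %arg2 : tensor<16xf32, #slice>
%2 = arith.addf %0, %1 : tensor<16xf32, #slice>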

File tree

2 files changed: +94 −12 lines changed


test/TritonIntelGPU/optimize-elementwise.mlir

Lines changed: 31 additions & 0 deletions
@@ -151,3 +151,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
     tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func @test_multi_user(
+  // CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_2:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>)
+  tt.func @test_multi_user(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg2: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+    // CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_4:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : tensor<16xf32, #[[$ATTR_0]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    // CHECK: %[[VAL_6:.*]] = triton_gpu.convert_layout %[[VAL_5]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    // CHECK: %[[VAL_7:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_8:.*]] = triton_gpu.convert_layout %[[VAL_2]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : tensor<16xf32, #[[$ATTR_0]]>
+    %1 = arith.addf %arg0, %arg2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    // CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_9]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    // CHECK: %[[VAL_11:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_12:.*]] = triton_gpu.convert_layout %[[VAL_10]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+    // CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] : tensor<16xf32, #[[$ATTR_0]]>
+    %2 = arith.addf %0, %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    // CHECK: %[[VAL_14:.*]] = triton_gpu.convert_layout %[[VAL_13]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    // CHECK: tt.return %[[VAL_14]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp

Lines changed: 63 additions & 12 deletions
@@ -77,6 +77,47 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
       linearLayout, numWorkGroupPos, rewriter);
 }
 
+/// Generic checks for the operation not looking at the tensor type.
+bool isCandidateOp(Operation *op) {
+  // Rely on this for a simpler pass.
+  if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
+      op->getNumResults() != 1)
+    return false;
+
+  // Skip complex operations.
+  if (op->hasSuccessors() || op->getNumRegions() != 0)
+    return false;
+
+  return true;
+}
+
+bool optimizationDoesNotWorsenRegisterPressure(
+    Value value, RankedTensorType newType, SmallPtrSetImpl<Value> &visited) {
+  if (!visited.insert(value).second)
+    return true;
+  // All users must be operations we will optimize too or layout conversions we
+  // will introduce later.
+  return llvm::all_of(value.getUses(), [&visited, newType](OpOperand &operand) {
+    Operation *owner = operand.getOwner();
+
+    // We will be introducing just this operation later.
+    if (auto convertLayout = dyn_cast<ConvertLayoutOp>(owner))
+      return convertLayout.getResult().getType() == newType;
+
+    // Only allow candidates. Check only operation constraints. We do not have
+    // to check the type as we did already.
+    if (!owner->hasTrait<OpTrait::Elementwise>() || !isCandidateOp(owner))
+      return false;
+
+    // Check other operands fit the constraints.
+    return llvm::all_of(owner->getOperands(),
+                        [&visited, newType](Value operand) {
+                          return optimizationDoesNotWorsenRegisterPressure(
+                              operand, newType, visited);
+                        });
+  });
+}
+
 /// Get optimized unbroadcasted tensor type.
 ///
 /// Get optimized ranked tensor type after unbroadcasting. As we only support 1D
@@ -110,13 +151,10 @@ struct ElementwiseOptPattern final
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
-    // Rely on this for a simpler pass.
-    if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
-        op->getNumResults() != 1)
-      return failure();
+    LLVM_DEBUG(llvm::dbgs() << "Checking operation:\n" << *op << "\n");
 
-    // Skip complex operations.
-    if (op->hasSuccessors() || op->getNumRegions() != 0)
+    // Rely on this for a simpler pass.
+    if (!isCandidateOp(op))
       return failure();
 
     // Layout optimizations only apply to tensors.
@@ -132,19 +170,30 @@ struct ElementwiseOptPattern final
       return failure();
     std::optional<LinearLayout> linearLayout =
         toLinearLayout(type.getShape(), layout);
-    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
-      return failure();
 
-    // Check the operands are not used by other operations. This will prevent
-    // register pressure increase:
-    if (!llvm::all_of(op->getOperands(),
-                      [](Value val) { return val.hasOneUse(); }))
+    LLVM_DEBUG(llvm::dbgs() << "Checking linear layout:\n"
+                            << linearLayout << "\n");
+
+    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
       return failure();
 
     // As we are dealing with 1D tensors, we can do a simple transform to obtain
     // a more optimized operation.
     Location loc = op->getLoc();
     RankedTensorType newType = getOptimizedType(type, *linearLayout, rewriter);
+
+    LLVM_DEBUG(llvm::dbgs() << "Would convert to type:\n" << newType << "\n");
+
+    // Check the operands are not used by other operations. This will prevent
+    // register pressure increase:
+    if (SmallPtrSet<Value, 2> visited;
+        !llvm::all_of(op->getOperands(), [&visited, newType](Value operand) {
+          return optimizationDoesNotWorsenRegisterPressure(operand, newType,
+                                                           visited);
+        }))
+      return failure();
+
+    // Obtain converted operands.
     SmallVector<Value> newOperands(op->getNumOperands());
     llvm::transform(op->getOperands(), std::begin(newOperands),
                     [&rewriter, loc, newType](Value operand) {
@@ -164,6 +213,8 @@ struct ElementwiseOptPattern final
     Value newValue = newElementwiseOp->getResult(0);
     rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, type, newValue);
 
+    LLVM_DEBUG(llvm::dbgs() << "Conversion took place.\n");
+
     return success();
   }
 };
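For reference, a condensed sketch of what the pattern produces for @test_multi_user above, reconstructed from the test's CHECK lines (#blocked abbreviates #[[$ATTR_0]], #slice the original slice-of-DPAS layout; value names are illustrative):

%a = triton_gpu.convert_layout %arg0 : tensor<16xf32, #slice> -> tensor<16xf32, #blocked>
%b = triton_gpu.convert_layout %arg1 : tensor<16xf32, #slice> -> tensor<16xf32, #blocked>
%s = arith.addf %a, %b : tensor<16xf32, #blocked>
%r = triton_gpu.convert_layout %s : tensor<16xf32, #blocked> -> tensor<16xf32, #slice>
// The same rewrite fires for the other two addf ops: the shared %arg0 is
// converted once per use, and each addf executes in the compact #blocked
// layout instead of the sliced DPAS layout.

The back-to-back #blocked -> #slice -> #blocked round trips between the adds are presumably left for later layout-conversion folding; the check added here only guarantees that the rewrite itself does not introduce extra live values.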
