
Commit 4b517b9

Further test
1 parent cc6c199 commit 4b517b9

3 files changed: +141 -111 lines changed


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
         ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
-        #benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='_attn_fwd')
 
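The re-enabled call compares the Triton kernel output against the PyTorch reference within the tolerances chosen above. As a rough sketch of what such a tolerance check does, using torch.testing.assert_close as a stand-in (benchmark_suit.assert_close is assumed, not confirmed, to behave similarly):

import torch

def check_triton_vs_torch(triton_out: torch.Tensor, torch_out: torch.Tensor, n_ctx: int) -> None:
    # Looser absolute tolerance for the longest sequence, mirroring the
    # benchmark's choice of 1e-1 at N_CTX == 16384 and 1e-2 otherwise.
    atol = 1e-1 if n_ctx == 16384 else 1e-2
    torch.testing.assert_close(triton_out.cpu().to(torch.float32),
                               torch_out.cpu().to(torch.float32),
                               atol=atol, rtol=1e-3)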

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp

Lines changed: 0 additions & 2 deletions
@@ -142,8 +142,6 @@ RankedTensorType getOptimizedType(RankedTensorType type,
   [[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
   assert(ctaSplitNum == 1 && "Expecting single CTA");
 
-  llvm::errs() << linearLayout << "\n";
-
   RankedTensorType::Builder typeBuilder(type);
   int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
   unsigned sizePerThread =

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 140 additions & 108 deletions
@@ -14,6 +14,7 @@
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 
 #define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality"
 
@@ -146,11 +147,15 @@ namespace {
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;
 
+  // Original reduction
   static constexpr int preferredNonReductionAxis = 0;
-  static constexpr int finalReductionAxis = 3;
   static constexpr int preferredReductionAxis = 1;
-  static constexpr int repCountReshapedAxis = 4;
-  static constexpr int withinWarpXAxisReshapedAxis = 6;
+
+  // Intermediate reductions
+  static constexpr int finalEWReductionAxis = 0;
+  static constexpr int finalWarpsReductionAxis = 2;
+  static constexpr int repCountReshapedAxis = 2;
+  static constexpr int withinWarpXAxisReshapedAxis = 5;
 
   LogicalResult matchAndRewrite(ReduceOp op,
                                 PatternRewriter &rewriter) const final {
@@ -185,39 +190,44 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
 
-    operand = reshapeForElementWiseReduction(op, rewriter);
+    operand = reshapeForElementWiseReduction(op, rewriter, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
                << "Reshaped for elementwise reduction: " << operand << "\n");
 
+    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Performed elementwise reduction within repCount: " << operand
+               << "\n");
+
     operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
                << "Performed elementwise reduction across repCount: " << operand
                << "\n");
 
-    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+    operand = reshapeForFinalReduction(op, rewriter, operand, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
-               << "Performed elementwise reduction within repCount: " << operand
-               << "\n");
+               << "Reshaped for final reduction: " << operand << "\n");
 
     operand = convertLayoutForFinalReduction(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
                << "Converted layout for final reduction: " << operand << "\n");
 
-    operand = reshapeForFinalReduction(op, rewriter, operand);
+    operand = performFinalElementwiseReduction(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
-               << "Reshaped for final reduction: " << operand << "\n");
+               << "Final elementwise reduction performed: " << operand << "\n");
 
-    operand = performFinalReduction(op, rewriter, operand);
+    operand = performFinalAcrossWarpsReduction(op, rewriter, operand);
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "Final reduction performed: " << operand << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Final across-warps reduction performed: "
+                            << operand << "\n");
 
-    operand = convertLayoutToOriginalType(op, rewriter, operand);
+    operand = convertLayoutToOriginalType(op, rewriter, operand, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
                << "Converted layout to original type: " << operand << "\n");
@@ -233,57 +243,65 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   }
 
 private:
-  Value reshapeForElementWiseReduction(ReduceOp op,
-                                       PatternRewriter &rewriter) const {
+  Value reshapeForElementWiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                       DpasEncodingAttr dpasEncoding) const {
     assert(op.getOperands().size() == 1 && "Expecting a single operand");
 
     Value val = op.getOperands().front();
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 7;
+    constexpr size_t rank = 8;
     std::array<int64_t, rank> shape{
-        // Y axis contiguous elements handled by a single thread.
-        oldEncoding.getExecutionSize(),
-        // Y axis contiguous elements handled by a single thread.
-        // Needs to be split from previous dimension to perform transpose.
-        (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]) /
-            oldEncoding.getExecutionSize(),
-        // Y axis rest.
-        oldShape[0] /
-            (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]),
-        // X axis contiguous elements distributed within individual threads in a
-        // warp.
-        oldEncoding.getExecutionSize(),
-        // X axis contiguous elements distributed within a warp.
-        oldEncoding.getRepCluster()[1],
-        // X axis number of warps.
-        oldEncoding.getWarpsPerCTA()[1],
-        // X axis rest.
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        dpasEncoding.getWarpsPerCTA()[1],
+        dpasEncoding.getWarpsPerCTA()[0],
        oldShape[1] /
-            (oldEncoding.getExecutionSize() * oldEncoding.getRepCluster()[1] *
-             oldEncoding.getWarpsPerCTA()[1])};
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
+        oldShape[0] /
+            (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
+             dpasEncoding.getWarpsPerCTA()[0])};
     std::array<unsigned, rank> sizePerThread{
-        oldEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
+        1,
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        1,
+        1,
+        static_cast<unsigned>(oldShape[1]) /
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
+        static_cast<unsigned>(oldShape[0]) /
+            (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
+             dpasEncoding.getWarpsPerCTA()[0])};
     std::array<unsigned, rank> threadsPerWarp{
-        1, 1, 1, oldEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[0],
-        1, 1, oldEncoding.getWarpsPerCTA()[1],
-        1};
-    std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
+        dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1,
+                                           1,
+                                           1,
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0],
+                                           1,
+                                           1};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6, 7};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }
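
The new 8-D reshape factors both axes of the reduction operand out of the DPAS encoding parameters. A small Python sketch of that factoring with made-up parameter values (execution size, repeat count, rep cluster, and warps-per-CTA below are hypothetical placeholders, not values read from a real DpasEncodingAttr); the point is that the factored dimensions multiply back to the original 2-D shape:

import math

# Hypothetical DPAS parameters; the real values come from the encoding.
execution_size = 16
repeat_count = 8
rep_cluster = (4, 2)      # (Y, X)
warps_per_cta = (4, 2)    # (Y, X)
old_shape = (512, 256)    # (Y, X) shape of the reduction operand

shape = (
    execution_size,
    repeat_count,
    rep_cluster[1],
    rep_cluster[0],
    warps_per_cta[1],
    warps_per_cta[0],
    old_shape[1] // (execution_size * rep_cluster[1] * warps_per_cta[1]),
    old_shape[0] // (repeat_count * rep_cluster[0] * warps_per_cta[0]),
)

# The reshape is only valid if the factored axes cover every element exactly.
assert math.prod(shape) == old_shape[0] * old_shape[1]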
@@ -315,100 +333,114 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
   Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
                                        Value val) const {
-    assert(op.getOperands().size() == 1 && "Expecting a single operand");
-
     auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
+    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
+    RankedTensorType::Builder type(oldType);
 
-    constexpr size_t rank = 5;
-    ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rank> sizePerThread{
-        1, 1, 1, dpasEncoding.getExecutionSize(), 1};
-    std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
-                                              1, 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{1, 1,
-                                           dpasEncoding.getWarpsPerCTA()[0], 1,
-                                           dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rank> order{3, 4, 0, 1, 2};
-    CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
+    SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
+    SmallVector<unsigned> threadsPerWarp = oldEncoding.getThreadsPerWarp();
+
+    std::swap(sizePerThread[0], sizePerThread[1]);
+    std::swap(threadsPerWarp[0], threadsPerWarp[1]);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
-        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+        sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA(),
+        oldEncoding.getOrder(), oldEncoding.getCTALayout());
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    type.setEncoding(encoding);
 
-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), static_cast<RankedTensorType>(type), val);
   }
 
   Value reshapeForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                                 Value val) const {
+                                 Value val,
+                                 DpasEncodingAttr dpasEncoding) const {
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 4;
-    std::array<int64_t, rank> shape{oldShape[0], oldShape[1], oldShape[2],
-                                    oldShape[3] * oldShape[4]};
-    std::array<unsigned, rank> sizePerThread{1, 1, 1,
-                                             oldEncoding.getSizePerThread()[3]};
+    constexpr size_t rank = 6;
+    std::array<int64_t, rank> shape{dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getRepeatCount() *
+                                        dpasEncoding.getRepCluster()[0] /
+                                        dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getWarpsPerCTA()[1],
+                                    dpasEncoding.getWarpsPerCTA()[0],
+                                    oldShape.back()};
+    std::array<unsigned, rank> sizePerThread{
+        1,
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
+            dpasEncoding.getExecutionSize(),
+        1,
+        1,
+        static_cast<unsigned>(oldShape.back())};
     std::array<unsigned, rank> threadsPerWarp{
-        oldEncoding.getThreadsPerWarp()[0], 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[2], oldEncoding.getWarpsPerCTA()[4]};
-    std::array<unsigned, rank> order{3, 0, 1, 2};
+        dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1,
+                                           1,
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0],
+                                           1};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }
 
-  Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                              Value val) const {
-    return performReduction(op, rewriter, val, /*axis=*/finalReductionAxis);
+  Value performFinalElementwiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val, /*axis=*/finalEWReductionAxis);
+  }
+
+  Value performFinalAcrossWarpsReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val,
+                            /*axis=*/finalWarpsReductionAxis);
   }
 
   Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
-                                    Value val) const {
+                                    Value val,
+                                    DpasEncodingAttr dpasEncoding) const {
     auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
-
-    // Only Y axis (X axis has already been reduced)
-    constexpr size_t rankBeforeLastReduction = 4;
-    ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rankBeforeLastReduction> sizePerThread{
-        dpasEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rankBeforeLastReduction> threadsPerWarp{
-        1, 1, 1, dpasEncoding.getExecutionSize()};
-    std::array<unsigned, rankBeforeLastReduction> warpsPerCTA{
-        1, 1, dpasEncoding.getWarpsPerCTA()[0],
-        dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
-    CTALayoutAttr ctaLayout =
-        CTALayoutAttr::getDefault(getContext(), rankBeforeLastReduction);
-
-    auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
+    ArrayRef<int64_t> oldShape = oldType.getShape();
+    RankedTensorType::Builder type(oldType);
+
+    constexpr size_t rank = 5;
+    std::array<unsigned, rank> sizePerThread{
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepCluster()[0] * dpasEncoding.getRepeatCount() /
+            dpasEncoding.getExecutionSize(),
+        1, 1, 1};
+    std::array<unsigned, rank> threadsPerWarp{1, 1, 1, 1,
+                                              dpasEncoding.getExecutionSize()};
+    std::array<unsigned, rank> warpsPerCTA{1, 1,
+                                           dpasEncoding.getWarpsPerCTA()[0], 1,
+                                           dpasEncoding.getWarpsPerCTA()[1]};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4};
+    CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
+
+    auto parentEncoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
-    auto encoding = rewriter.getAttr<SliceEncodingAttr>(finalReductionAxis,
-                                                        blockedEncoding);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    type.setEncoding(parentEncoding.squeeze(rank - 1));
 
-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), static_cast<RankedTensorType>(type), val);
   }
 
   Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,
