
Commit 1c2f1a8
Merge branch 'main' into sub-group-slm-transpose
2 parents 57b4375 + 24e53d2

File tree: 13 files changed, +354 −12 lines


.github/workflows/e2e-accuracy.yml

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,10 @@ on:
         - all
         - subset
       default: all
+    check_all_subset_models:
+      description: In "subset" mode, check all subset models
+      type: boolean
+      default: false
     only_one_model:
       description: Run only this one model
       type: string
@@ -125,6 +129,7 @@ jobs:
       test_mode: accuracy
       dtype: ${{ matrix.dtype }}
       models: ${{ inputs.models }}
+      check_all_subset_models: ${{ inputs.check_all_subset_models || false }}
       only_one_model: ${{ inputs.only_one_model }}
       runner_label: ${{ inputs.runner_label }}
       TORCH_COMPILE_DEBUG: ${{ inputs.TORCH_COMPILE_DEBUG }}

.github/workflows/e2e-performance.yml

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,10 @@ on:
         - all
         - subset
       default: subset
+    check_all_subset_models:
+      description: In "subset" mode, do not fail workflow if one of models failed
+      type: boolean
+      default: false
     only_one_model:
       description: Run only this one model
       type: string
@@ -136,6 +140,7 @@ jobs:
       test_mode: performance
       dtype: ${{ matrix.dtype }}
       models: ${{ inputs.models }}
+      check_all_subset_models: ${{ inputs.check_all_subset_models || false }}
       only_one_model: ${{ inputs.only_one_model }}
       runner_label: ${{ inputs.runner_label }}
       TORCH_COMPILE_DEBUG: ${{ inputs.TORCH_COMPILE_DEBUG }}

.github/workflows/e2e-reusable.yml

Lines changed: 15 additions & 1 deletion
@@ -27,6 +27,10 @@ on:
       description: Run all models or a subset
       type: string
       default: all
+    check_all_subset_models:
+      description: In "subset" mode, check all subset models
+      type: boolean
+      default: false
     only_one_model:
       description: Run only this one model
       type: string
@@ -224,9 +228,19 @@
           if [[ "${{ inputs.only_one_model }}" ]]; then
             bash -e $GITHUB_WORKSPACE/scripts/inductor_xpu_test.sh ${{ inputs.suite }} ${{ inputs.dtype }} ${{ inputs.mode }} ${{ inputs.test_mode }} xpu 0 static 1 0 ${{ inputs.only_one_model }}
           elif [[ "${{ inputs.models }}" == "subset" ]]; then
+            models_subset_file="$GITHUB_WORKSPACE/.github/models/${{ inputs.test_mode }}/${{ inputs.suite }}.txt"
             while read model; do
               bash -e $GITHUB_WORKSPACE/scripts/inductor_xpu_test.sh ${{ inputs.suite }} ${{ inputs.dtype }} ${{ inputs.mode }} ${{ inputs.test_mode }} xpu 0 static 1 0 $model
-            done < $GITHUB_WORKSPACE/.github/models/${{ inputs.test_mode }}/${{ inputs.suite }}.txt
+            done < $models_subset_file
+            if [[ "${{ inputs.check_all_subset_models }}" == true ]]; then
+              python $GITHUB_WORKSPACE/scripts/check_inductor_report.py --models-file="$models_subset_file" \
+                --suite=${{ inputs.suite }} \
+                --dtype=${{ inputs.dtype }} \
+                --mode=${{ inputs.mode }} \
+                --test_mode=${{ inputs.test_mode }} \
+                --device=xpu \
+                --inductor-log-dir="${GITHUB_WORKSPACE}/inductor_log"
+            fi
           else
             bash -e $GITHUB_WORKSPACE/scripts/inductor_xpu_test.sh ${{ inputs.suite }} ${{ inputs.dtype }} ${{ inputs.mode }} ${{ inputs.test_mode }} xpu 0 static 1 0
           fi

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 2 deletions
@@ -171,7 +171,7 @@ def forward(q, k, v, causal, sm_scale):
     assert Lk in {16, 32, 64, 128}
     o = torch.empty_like(q, dtype=torch.float32)
     BLOCK_M = 128
-    BLOCK_N = 64 if Lk <= 64 else 32
+    BLOCK_N = 64
     num_stages = 3
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
@@ -205,7 +205,8 @@ def forward(q, k, v, causal, sm_scale):
         BLOCK_DMODEL=Lk, #
         STAGE=stage, #
         num_warps=num_warps, #
-        num_stages=num_stages #
+        num_stages=num_stages, #
+        grf_mode='large', #
     )
     return o

scripts/check_inductor_report.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+import argparse
+from pathlib import Path
+import csv
+import sys
+
+
+def check_report(suite, dtype, mode, test_mode, device, models_file, inductor_log_dir):
+    inductor_log_dir_leaf = Path(inductor_log_dir) / suite / dtype
+    inductor_report_filename = f"inductor_{suite}_{dtype}_{mode}_{device}_{test_mode}.csv"
+    inductor_report_path = Path(inductor_log_dir_leaf / inductor_report_filename)
+
+    subset = []
+    report = []
+    exitcode = 0
+
+    with open(models_file, encoding="utf-8") as f:
+        subset = f.read().splitlines()
+
+    with open(inductor_report_path, encoding="utf-8") as f:
+        reader = csv.reader(f)
+        report_with_header = []
+        for l in reader:
+            report_with_header.append(l)
+        for r in report_with_header[1:]:
+            if r[0] == device:
+                report.append(r)
+
+    test_list = [r[1] for r in report]
+
+    if test_mode == "performance":
+        for m in subset:
+            if m not in test_list:
+                exitcode = 1
+                print(f"Test is not found in report: {m}")
+
+    if test_mode == "accuracy":
+        test_statuses = [r[3] for r in report]
+        for m in subset:
+            try:
+                idx = test_list.index(m)
+            except ValueError:
+                exitcode = 1
+                print(f"Test is NOT FOUND: {m}")
+                continue
+            if test_statuses[idx] != "pass":
+                exitcode = 1
+                print(f"Test is NOT PASSED: {m}")
+    return exitcode
+
+
+def main():
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument("--suite", required=True)
+    argparser.add_argument("--dtype", required=True)
+    argparser.add_argument("--mode", required=True, choices=("inference", "training", "inference-no-freezing"))
+    argparser.add_argument("--test_mode", required=True, choices=("performance", "accuracy"))
+    argparser.add_argument("--device", help="i.e. xpu", required=True)
+    argparser.add_argument("--models-file", help="Subset of models list", required=True)
+    argparser.add_argument("--inductor-log-dir", help="Inductor test log directory", default="inductor_log")
+    args = argparser.parse_args()
+    exitcode = check_report(args.suite, args.dtype, args.mode, args.test_mode, args.device, args.models_file,
+                            args.inductor_log_dir)
+    print(f"Report check result: {'SUCCESS' if exitcode == 0 else 'FAIL'}")
+    sys.exit(exitcode)
+
+
+if __name__ == "__main__":
+    main()
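
For reference, a minimal sketch of how the new checker behaves when driven directly from Python rather than from the workflow. The suite, model names, and temporary paths below are hypothetical; the CSV layout (one header row, then device in column 0, model name in column 1, accuracy status in column 3) mirrors the parsing logic above, and the example assumes scripts/ is on PYTHONPATH:

import csv
import tempfile
from pathlib import Path

from check_inductor_report import check_report  # assumes scripts/ is importable

log_dir = Path(tempfile.mkdtemp())
(log_dir / "huggingface" / "float32").mkdir(parents=True)

# Hypothetical accuracy report: header row, then one row per (device, model).
report = (log_dir / "huggingface" / "float32" /
          "inductor_huggingface_float32_inference_xpu_accuracy.csv")
with open(report, "w", newline="", encoding="utf-8") as f:
    csv.writer(f).writerows([
        ["dev", "name", "batch_size", "accuracy"],
        ["xpu", "model_a", "8", "pass"],
        ["xpu", "model_b", "8", "fail"],
    ])

# The subset lists three models: model_b failed and model_c never ran.
models_file = log_dir / "models.txt"
models_file.write_text("model_a\nmodel_b\nmodel_c\n", encoding="utf-8")

# Expect exit code 1, printing "Test is NOT PASSED: model_b"
# and "Test is NOT FOUND: model_c".
assert check_report("huggingface", "float32", "inference", "accuracy",
                    "xpu", models_file, log_dir) == 1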

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def make_ttgir(mod, metadata, opt, properties):
     intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
     intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
 
-    passes.ttgpuir.add_coalesce(pm)
+    intel.passes.ttgpuir.add_coalesce(pm)
     intel.passes.ttgpuir.add_remove_layout_conversions(pm)
     passes.ttgpuir.add_optimize_thread_locality(pm)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)

third_party/intel/include/Analysis/AxisInfo.h

Lines changed: 4 additions & 3 deletions
@@ -27,11 +27,12 @@ class AxisInfo {
 public:
   AxisInfo() : AxisInfo({}, {}, {}) {}
 
-  AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
+  AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
+           const DimVectorT &constancy)
       : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}
 
-  AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
-           std::optional<int64_t> constantValue)
+  AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
+           const DimVectorT &constancy, std::optional<int64_t> constantValue)
       : contiguity(contiguity), divisibility(divisibility),
         constancy(constancy), constantValue(constantValue) {
     assert(divisibility.size() == contiguity.size());

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 28 additions & 2 deletions
@@ -9,11 +9,37 @@
 #ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
 #define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
 
-#include <optional>
-
+#include "intel/include/Analysis/AxisInfo.h"
+#include "mlir/IR/Operation.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <triton/Tools/Sys/GetEnv.hpp>
 
 namespace mlir::triton::gpu::intel {
+
+/// Calculate the optimal number of elements per thread for a given operation
+/// along an axis with greatest continuity.
+inline unsigned getNumElementsPerThread(
+    Operation *op, SmallVector<unsigned> order,
+    mlir::triton::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  Value val = getMemAccessPtr(op);
+  Type valTy = val.getType();
+  auto ty =
+      isTensorPointerType(valTy)
+          ? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
+          : cast<RankedTensorType>(valTy);
+  auto shapePerCTA = getShapePerCTA(ty);
+  mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
+
+  unsigned elemNumBits = getElementBitWidth(ty);
+  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
+  unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
+  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
+  unsigned maxContig =
+      std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]);
+  unsigned alignment = std::min(maxMultiple, maxContig);
+  return std::min(alignment, 128 / elemNumBits);
+}
+
 /// Check whether transposed reduction should be performed.
 ///
 /// See: https://github.com/intel/intel-xpu-backend-for-triton/issues/1637
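
As a cross-check of the arithmetic in getNumElementsPerThread, here is a small Python sketch of the same computation: divisibility is given in bytes and converted to elements, the contiguous run is clamped to the per-CTA shape, and the result is capped at a 128-bit access per thread. The names mirror the C++ above; the sample numbers are made up:

def num_elements_per_thread(elem_num_bits, divisibility_bytes, contiguity, shape_per_cta):
    # Bytes per element, at least 1 (guards sub-byte element types).
    elem_num_bytes = max(elem_num_bits // 8, 1)
    # Known alignment of the pointer, converted from bytes to elements.
    max_multiple = max(divisibility_bytes // elem_num_bytes, 1)
    # A contiguous run cannot exceed the per-CTA shape along this axis.
    max_contig = min(contiguity, shape_per_cta)
    alignment = min(max_multiple, max_contig)
    # Cap at a 128-bit access per thread.
    return min(alignment, 128 // elem_num_bits)

# fp16 tensor, 16-byte divisibility, 64 contiguous elements, CTA shape 128:
# min(min(16 // 2, 64), 128 // 16) = min(8, 8) = 8 elements per thread.
assert num_elements_per_thread(16, 16, 64, 128) == 8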

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
@@ -27,6 +27,22 @@ def TritonIntelGPUAccelerateMatmul
   ];
 }
 
+def TritonIntelGPUCoalesce
+    : Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> {
+  let summary = "Intel Coalesce";
+
+  let description = [{
+    The pass analyses loads/stores with type `tensor<tt.ptr<>>` or
+    `tt.ptr<tensor<>>` and replaces the layouts of these operations with
+    coalesced layouts, i.e. cache friendly access patterns.
+    Layout conversions are inserted before and after the load/store op
+    to maintain consistency with the rest of the program.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonIntelGPUDistributeToWarps
     : Pass<"tritonintelgpu-distribute-to-warps", "mlir::ModuleOp"> {
   let summary = "distribute the thread block workload to the warps";
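
For intuition about what the new pass targets: loads/stores over blocks of consecutive pointers, as in this generic Triton kernel (an illustration, not code from this commit), are the candidates whose layouts get rewritten to coalesced ones:

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Consecutive lanes address consecutive elements; with a coalesced
    # layout these loads/stores lower to wide, cache-friendly accesses.
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)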

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 7 additions & 3 deletions
@@ -1010,8 +1010,12 @@ class MakeTensorPtrOpAxisInfoVisitor final
   getAxisInfo(triton::MakeTensorPtrOp op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op);
-    assert(op.getShape().size() == 2 && operands.size() == 7 &&
-           "MakeTensorPtrOp should have 2D shape");
+
+    // TODO: Extend to higher dimension tensor pointers.
+    if (op.getShape().size() != 2)
+      return AxisInfo();
+
+    assert(operands.size() == 7 && "MakeTensorPtrOp should have 2D shape");
 
     AxisInfo ptrInfo = operands[0]->getValue();
     AxisInfo shapeInfo0 = operands[1]->getValue();
@@ -1344,7 +1348,7 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
     } else {
       curAxisInfo = axisInfo;
     }
-    (*axisInfoMap)[value] = curAxisInfo;
+    (*axisInfoMap)[value] = std::move(curAxisInfo);
   };
   funcOp.walk([&](Operation *op) {
     for (auto value : op->getResults()) {
