Commit 1487600

Merge branch 'main' into ksuvorov/fix800
2 parents 35b6615 + b79ceaa

File tree

  benchmarks/triton_kernels_benchmark/benchmark_testing.py
  benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
  third_party/intel/lib/Analysis/AxisInfo.cpp

3 files changed: +19 −18 lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 6 additions & 4 deletions
@@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

     function_events = prof.events()

-    functions = []
+    all_functions = []
     if isinstance(kernel_name, str):
         kernel_name = [kernel_name]
     for ker_name in kernel_name:
-        functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events)))  # pylint: disable=cell-var-from-loop
+        functions = list(filter(lambda x: x.name.startswith(ker_name), function_events))  # pylint: disable=cell-var-from-loop
+        assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
+        all_functions.append(functions)
     # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

-    assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
     # Make the time to the milliseconds.
-    times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
+    times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
+                         dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)
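The new logic collects one event list per kernel name, checks that each kernel fired exactly n_repeat times, and then sums the i-th repetition across all kernels. A minimal sketch of that aggregation, using made-up stand-in event objects rather than real torch.profiler output:

    from types import SimpleNamespace

    n_repeat = 2
    # Hypothetical events; only .self_device_time_total (microseconds) matters here.
    kernel_a = [SimpleNamespace(self_device_time_total=100.0),
                SimpleNamespace(self_device_time_total=110.0)]
    kernel_b = [SimpleNamespace(self_device_time_total=40.0),
                SimpleNamespace(self_device_time_total=50.0)]

    all_functions = [kernel_a, kernel_b]  # one list of n_repeat events per kernel

    # zip(*all_functions) pairs the i-th repetition of every kernel, so each
    # repetition is charged the combined device time of all profiled kernels.
    times_ms = [sum(e.self_device_time_total for e in rep) * 1e-3
                for rep in zip(*all_functions)]
    print(times_ms)  # -> [0.14, 0.16]

Moving the assert inside the loop also sharpens the failure message: it now names the kernel whose event count diverged from n_repeat instead of reporting a bare mismatch.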

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 3 deletions
@@ -293,9 +293,8 @@ def benchmark(M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
-            xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-            kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='stream_k_gemm_run')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
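Since kernel_name may now also be a list (see benchmark_testing.py above), a call site whose provider launches several device kernels could charge all of them to one measurement. A hypothetical sketch, with illustrative kernel names that are not taken from this repository:

    # Hypothetical: sum the device time of two kernels per repetition.
    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
        xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
        kernel_name=['first_pass_kernel', 'reduce_kernel'])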

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 11 additions & 11 deletions
@@ -105,8 +105,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
   AxisInfo
   getAxisInfo(OpTy op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
-    auto lhsInfo = operands[0]->getValue();
-    auto rhsInfo = operands[1]->getValue();
+    const auto &lhsInfo = operands[0]->getValue();
+    const auto &rhsInfo = operands[1]->getValue();
     auto rank = lhsInfo.getRank();
     assert(operands.size() == 2 && "Expected two operands");
     AxisInfo::DimVectorT contiguity;

@@ -658,8 +658,8 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
       return AxisInfo();
     auto shape = resTy.getShape();
     short rank = resTy.getRank();
-    auto lhsInfo = operands[0]->getValue();
-    auto rhsInfo = operands[1]->getValue();
+    const auto &lhsInfo = operands[0]->getValue();
+    const auto &rhsInfo = operands[1]->getValue();

     AxisInfo::DimVectorT contiguity, divisibility, constancy;
     std::optional<int64_t> constantValue;

@@ -782,8 +782,8 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   getAxisInfo(OpTy op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto condConstancy = operands[0]->getValue().getConstancy();
-    auto lhsInfo = operands[1]->getValue();
-    auto rhsInfo = operands[2]->getValue();
+    const auto &lhsInfo = operands[1]->getValue();
+    const auto &rhsInfo = operands[2]->getValue();
     auto rank = lhsInfo.getRank();

     AxisInfo::DimVectorT contiguity, divisibility, constancy;

@@ -967,8 +967,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   AxisInfo
   getAxisInfo(OpTy op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
-    auto lhsInfo = operands[0]->getValue();
-    auto rhsInfo = operands[1]->getValue();
+    const auto &lhsInfo = operands[0]->getValue();
+    const auto &rhsInfo = operands[1]->getValue();
     auto rank = lhsInfo.getRank();
     std::optional<int64_t> constantValue;
     if (lhsInfo.getConstantValue().has_value() &&

@@ -1140,8 +1140,8 @@ LogicalResult AxisInfoAnalysis::visitOperation(

 void AxisInfoAnalysis::visitForOpInductionVar(
     scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
-  auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
-  auto step = getLatticeElementFor(op, op.getStep())->getValue();
+  const auto &lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
+  const auto &step = getLatticeElementFor(op, op.getStep())->getValue();

   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);

@@ -1337,7 +1337,7 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
     return;
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
-    auto axisInfo = analysis->getLatticeElement(value)->getValue();
+    const auto &axisInfo = analysis->getLatticeElement(value)->getValue();
     AxisInfo curAxisInfo;
     if (axisInfoMap->count(value)) {
       curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
