Commit 290cf4e

Remove advanced path (#5205)
Advanced path is no longer functional for flash attention.

Closes #3286
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 4249232 commit 290cf4e
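
For orientation, a minimal sketch of the dispatch this commit removes, reconstructed from the flash-attention benchmark diff below; launch_default and launch_advanced are hypothetical stand-ins for the real kernel launches (tune_attn_fwd and attn_fwd):

import os

def launch_default():
    # stand-in for _attention.tune_attn_fwd[grid](...), the pipeline that remains
    print("default pipeline")

def launch_advanced():
    # stand-in for _attention.attn_fwd[grid](..., advanced_path=True), removed by this commit
    print("advanced path")

# Before this commit the benchmark branched on an environment variable;
# after it, only the default pipeline is launched and the variable is gone.
if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
    launch_default()
else:
    launch_advanced()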

53 files changed: +64 / -7118 lines

.github/workflows/build-test-reusable.yml

Lines changed: 0 additions & 1 deletion
@@ -342,7 +342,6 @@ jobs:
           07-extern-functions
           08-grouped-gemm
           10-experimental-block-pointer
-          10i-experimental-block-pointer
           EOF

       - name: Run Tutorials

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 12 additions & 36 deletions
@@ -440,49 +440,25 @@ def forward(ctx, q, k, v, causal, sm_scale):
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
-        BLOCK_M = 128
-        BLOCK_N = 64
-        num_stages = 3
-        num_warps = 8 if Lq == 64 else 16
         stage = 3 if causal else 1
         grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
         n_ctx = q.shape[2]
         if n_ctx <= 512:
             grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), 1, q.shape[0] * q.shape[1])
         M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

-        if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
-            # default pipeline
-            _attention.tune_attn_fwd[grid](  # pylint: disable=unsubscriptable-object
-                q, k, v, sm_scale, M, o,  #
-                q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
-                k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
-                v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
-                o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
-                q.shape[0], q.shape[1],  #
-                N_CTX=q.shape[2],  #
-                BLOCK_DMODEL=Lk,  #
-                STAGE=stage,  #
-                split_barriers_scope='None',  # possible scope value: 'Subgroup','Workgroup'
-            )
-        else:
-            _attention.attn_fwd[grid](  # pylint: disable=unsubscriptable-object
-                q, k, v, sm_scale, M, o,  #
-                q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
-                k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
-                v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
-                o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
-                q.shape[0], q.shape[1],  #
-                N_CTX=q.shape[2],  #
-                BLOCK_M=BLOCK_M,  #
-                BLOCK_N=BLOCK_N,  #
-                BLOCK_DMODEL=Lk,  #
-                STAGE=stage,  #
-                num_warps=num_warps,  #
-                num_stages=num_stages,  #
-                grf_mode='large',  #
-                advanced_path=True,  #
-            )
+        _attention.tune_attn_fwd[grid](  # pylint: disable=unsubscriptable-object
+            q, k, v, sm_scale, M, o,  #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
+            q.shape[0], q.shape[1],  #
+            N_CTX=q.shape[2],  #
+            BLOCK_DMODEL=Lk,  #
+            STAGE=stage,  #
+            split_barriers_scope='None',  # possible scope value: 'Subgroup','Workgroup'
+        )

         ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
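
As context for the hunk above, a small runnable sketch of the grid selection that is kept: contexts of at most 512 tokens fold batch and heads into one launch axis, longer ones keep (batch, heads, M-blocks). The tensor shape and BLOCK_M value are illustrative, not taken from the benchmark's autotuner.

import torch
import triton

q = torch.empty(2, 4, 1024, 64)  # (batch, heads, seq_len, head_dim); illustrative shape
BLOCK_M = 128                    # illustrative tile size; the benchmark autotunes this

if q.shape[2] <= 512:
    grid = (triton.cdiv(q.shape[2], BLOCK_M), 1, q.shape[0] * q.shape[1])
else:
    grid = (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], BLOCK_M))

print(grid)  # (2, 4, 8) for the shape above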

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
@@ -8,7 +8,6 @@
 #include "intel/include/TritonGENToLLVM/Passes.h"
 #include "intel/include/TritonGENToSPIRV/Passes.h"
 #include "intel/include/TritonIntelGPUToLLVM/Passes.h"
-#include "intel/include/TritonToTritonGPUWarp/Passes.h"

 #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "amd/include/TritonAMDGPUTransforms/Passes.h"
@@ -89,7 +88,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAMDGPUMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
-  mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
   mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
   mlir::triton::intel::registerTritonIntelRemoveMasks();
   mlir::triton::registerRelayoutTritonGPUPass();

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 5 deletions
@@ -44,17 +44,12 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
     "TRITON_PREFER_TMEM_16x256_LAYOUT",
-    "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
-    "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
     "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
-    "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
-    "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",
     "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
     "TRITON_INTEL_PREDICATED",
-    "TRITON_INTEL_REDUCE_TRANSPOSE",
     // clang-format on
 };
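
The list above is the set of environment variables assumed to invalidate Triton's compilation cache; the sketch below only illustrates that idea (hashing the variables' values into a key) and is not the library's actual mechanism.

import hashlib
import os

# Two names taken from the list above; any change to their values should
# yield a different key and therefore a recompilation (assumed behavior).
CACHE_INVALIDATING_ENV_VARS = ["TRITON_F32_DEFAULT", "TRITON_INTEL_FAST_MATH"]

def env_cache_key() -> str:
    parts = [f"{name}={os.getenv(name, '')}" for name in sorted(CACHE_INVALIDATING_ENV_VARS)]
    return hashlib.sha256(";".join(parts).encode()).hexdigest()

print(env_cache_key())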

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 5 deletions
@@ -105,11 +105,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
 }

 bool ReduceOpHelper::isWarpSynchronous() {
-  // FIXME: In the default path tensors will always have a layout. Tensors do
-  // not have a layout only in the advanced path. We need to find a workaround
-  // in order to remove this change.
-  if (!srcEncoding)
-    return true;
   return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 14 deletions
@@ -1426,8 +1426,6 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
 LogicalResult
 SliceEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           unsigned dim, DistributedEncodingTrait parent) {
-  if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
-    return success();
   unsigned rank = cast<LayoutEncodingTrait>(parent).getRank();
   if (rank <= 1)
     return emitError() << "parent layout must have at least rank >= 2";
@@ -2558,13 +2556,6 @@ LogicalResult DotOperandEncodingAttr::verify(
     return success();
   }

-  if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
-    if (kWidth != 0)
-      return emitError() << "ttg.dot_op kWidth parameter is not supported "
-                            "when the parent is a warp layout";
-    return success();
-  }
-
   if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
     if (kWidth != 0)
       return emitError() << "ttg.dot_op kWidth parameter is not supported "
@@ -2597,9 +2588,6 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
     } else if (auto linearAttr = mlir::dyn_cast<LinearEncodingAttr>(attr)) {
       os << "linear";
       return AliasResult::FinalAlias;
-    } else if (auto warpAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(attr)) {
-      os << "warp";
-      return AliasResult::FinalAlias;
     } /* else if (auto sliceAttr = dyn_cast<SliceEncodingAttr>(attr)) {
       os << "slice";
       return AliasResult::FinalAlias;
@@ -3298,8 +3286,6 @@ struct TritonGPUVerifyTensorLayoutInterface
     if (!distr)
       return makeErr()
              << "Non-distributed layout is not allowed in tensor type.";
-    if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
-      return success();
     auto rank = distr.getRepOrder().size();
     if (rank != rankedTy.getRank())
       return makeErr() << "Layout has rank " << rank

python/triton/knobs.py

Lines changed: 0 additions & 2 deletions
@@ -548,9 +548,7 @@ class intel_knobs(base_knobs):
     dump_shader_info: env_bool = env_bool("TRITON_INTEL_ENABLE_IGC_SHADER_DUMP", False)
     gen_native_code: env_bool = env_bool("TRITON_XPU_GEN_NATIVE_CODE", False)
     tile_load_ll: env_bool = env_bool("TRITON_XPU_ENABLE_TILE_LOAD_LINEAR_LAYOUT", True)
-    advanced_path: env_bool = env_bool("TRITON_INTEL_ADVANCED_PATH", False)
     opt_reduction_locality: env_bool = env_bool("TRITON_INTEL_OPTIMIZE_REDUCTION_LOCALITY", False)
-    reduce_transpose: env_bool = env_bool("TRITON_INTEL_REDUCE_TRANSPOSE", False)
     disable_igc_opt: env_bool = env_bool("TRITON_INTEL_DISABLE_IGC_OPT", False)

     dump_spirv_kernel_args: env_opt_str = env_opt_str("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
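
For readers unfamiliar with these knobs, a minimal sketch of what an env_bool-style knob does, under the assumption that it simply reads an environment variable with a default; this is not the actual triton.knobs implementation.

import os

def env_bool(name: str, default: bool = False) -> bool:
    # Assumed semantics: unset -> default, otherwise truthy strings enable the knob.
    val = os.getenv(name)
    if val is None:
        return default
    return val.lower() in ("1", "true", "on", "yes")

dump_shader_info = env_bool("TRITON_INTEL_ENABLE_IGC_SHADER_DUMP", False)
print(dump_shader_info)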

python/tutorials/10-experimental-block-pointer.py

Lines changed: 3 additions & 13 deletions
@@ -90,15 +90,11 @@
 # Final Result
 # ------------

-import os
-
 import torch

 import triton
 import triton.language as tl

-SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0'
-

 @triton.autotune(
     configs=[
@@ -107,18 +103,14 @@
         num_stages=s, num_warps=32) for s in [1, 2, 3]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
-                      num_stages=s, num_warps=w)
-        for s in [2, 3, 4]
-        for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
+                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
     ] + [
         triton.Config(
             {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
-                      num_stages=s, num_warps=w)
-        for s in [2, 3]
-        for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
+                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
     ],
     key=['M', 'N', 'K'],
 )
@@ -349,9 +341,7 @@ def matmul(a, b, accum_dtype, res_dtype):
 FP8_TYPES = [(torch.float8_e4m3fn, torch.float32, torch.float16)]

 torch.manual_seed(0)
-for dtype, accum_dtype, res_dtype in FP16_TYPES + FP32_TYPES + INT8_TYPES + (FP8_TYPES if os.getenv(
-        'TRITON_INTEL_ADVANCED_PATH', '0') == '0' else []):
-
+for dtype, accum_dtype, res_dtype in FP16_TYPES + FP32_TYPES + INT8_TYPES + FP8_TYPES:
     for shape in [(512, 512), (4, 512, 512)]:
         assert shape[-1] == shape[-2], "Only square matrices are supported"
         if dtype.is_floating_point:
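
The configuration list above feeds triton.autotune, which benchmarks every triton.Config the first time a new key (here M, N, K) is seen and caches the fastest one. A minimal, self-contained sketch of the same pattern on a toy kernel (not the tutorial's matmul):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
    ],
    key=['n_elements'],  # re-tune whenever this argument changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, factor, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * factor, mask=mask)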

python/tutorials/10i-experimental-block-pointer.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

scripts/run_tutorial.py

Lines changed: 0 additions & 3 deletions
@@ -58,9 +58,6 @@ def run_tutorial(path: pathlib.Path) -> float:
     if not spec or not spec.loader:
         raise AssertionError(f'Failed to load module from {path}')
     module = importlib.util.module_from_spec(spec)
-    # Set __file__ to the absolute name, a workaround for 10i-experimental-block-pointer, which
-    # uses dirname of its location to find 10-experimental-block-pointer.
-    module.__file__ = path.resolve().as_posix()
     # Reset sys.argv because some tutorials, such as 09, parse their command line arguments.
     sys.argv = [str(path)]
     start_time = datetime.datetime.now()
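
The surrounding loader code is standard importlib usage; a self-contained sketch of the same pattern (the path below is illustrative):

import importlib.util
import pathlib
import sys

path = pathlib.Path("python/tutorials/10-experimental-block-pointer.py")
spec = importlib.util.spec_from_file_location(path.stem, path)
if not spec or not spec.loader:
    raise AssertionError(f"Failed to load module from {path}")
module = importlib.util.module_from_spec(spec)
sys.argv = [str(path)]           # some tutorials parse command-line arguments
spec.loader.exec_module(module)  # running the module executes the tutorial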
