Commit ae77f77

Merge branch 'main' into gregory/windows-support
2 parents: 62774dc + a3adef5

57 files changed (+1247, -1203 lines)


.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ runs:
     uses: ./.github/actions/load
     env:
       # Increase this value to reset cache
-      CACHE_NUMBER: 11
+      CACHE_NUMBER: 12
     with:
       path: pytorch
       key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-8321eec009c8c79145ebccd51fdfc336e5f8b848
+487873f7cafeb0fd390eaefe40496b804bceabbd

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ python/triton/language/extra
 # Proton
 python/triton/profiler
 
+# Instrumentation
+python/triton/instrumentation
+
 # Python caches
 __pycache__/
 *.py[cod]

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 2 deletions
@@ -171,7 +171,7 @@ def forward(q, k, v, causal, sm_scale):
     assert Lk in {16, 32, 64, 128}
     o = torch.empty_like(q, dtype=torch.float32)
     BLOCK_M = 128
-    BLOCK_N = 64 if Lk <= 64 else 32
+    BLOCK_N = 64
     num_stages = 3
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
@@ -205,7 +205,8 @@ def forward(q, k, v, causal, sm_scale):
         BLOCK_DMODEL=Lk,  #
         STAGE=stage,  #
         num_warps=num_warps,  #
-        num_stages=num_stages  #
+        num_stages=num_stages,  #
+        grf_mode='large',  #
     )
     return o
 

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
@@ -91,8 +91,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
-  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
-  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
 #endif
 
   // TODO: register Triton & TritonGPU passes

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 5 additions & 4 deletions
@@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       // encoding not available
       return resultVals;
     Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma layout. Skip for now.
-      // We saw mismatches for some flash-attention tests on AMD backend.
-      // Note that this logic works for sliced layout whose parent is
+    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
+        isa<AMDWmmaEncodingAttr>(baseEncoding))
+      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
+      // now. We saw mismatches for some flash-attention and dot tests on AMD
+      // backend. Note that this logic works for sliced layout whose parent is
       // mfma layout. Therefore, this is not combined with the following check.
       return resultVals;
     while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
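The reworked comment is terse about why the MFMA/WMMA early return is kept separate from the slice-peeling loop that follows it, so here is a small illustrative sketch: a sliced layout whose parent is MFMA must not be rejected up front, because the reordering logic is known to work once the slice is unwrapped to its parent. The encoding value below is hypothetical and only stands in for that case.

// Illustration only; `enc` is a hypothetical sliced-of-MFMA encoding.
Attribute enc = sliceOfMfmaEncoding;
// The early return above does not fire: `enc` itself is a SliceEncodingAttr,
// not an AMDMfmaEncodingAttr or AMDWmmaEncodingAttr.
while (auto sliced = dyn_cast<SliceEncodingAttr>(enc))
  enc = sliced.getParent(); // unwrap to the MFMA parent and continue below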

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 22 deletions
@@ -27,33 +27,15 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
 constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
 
-struct BackendCallbacks {
-  /**
-   * A backend-specific callback for appending auxiliary data during
-   * `LocalStoreOp` conversion.
-   *
-   * @param[in] op The reference to the re-written `LocalStoreOp`.
-   * @param[in] count The number of issued LLVM instructions.
-   * @param[in] type The input type of issued LLVM instructions.
-   */
-  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
-                     Type llvmOpType)>
-      localStoreOpConversion = nullptr;
-};
-
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit);
 
-// The given callback is invoked at the end of a successful rewrite. The
-// callback receives 1) the current source op, 2) the number of issued LLVM
-// instructions and 3) their input types. Each MLIR backend can provide a
-// callback and, thus, handle backend-specific behaviors.
-void populateMemoryOpToLLVMPattern(
-    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit,
-    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
+void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                   const TargetInfoBase &targetInfo,
+                                   RewritePatternSet &patterns,
+                                   PatternBenefit benefit);
 
 void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
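For context on what this removal drops: the deleted comments describe a per-backend hook that fired after each successful LocalStoreOp rewrite and reported how many LLVM instructions were issued and with what input type. Below is a minimal sketch, written against the pre-change header, of how a backend could have supplied that hook; the wrapper function and the lambda body are hypothetical, and only the BackendCallbacks struct and the populateMemoryOpToLLVMPattern signature come from the deleted lines.

// Hypothetical wiring of the removed BackendCallbacks hook (sketch only).
void populateMemoryPatternsWithInstrumentation(
    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
    RewritePatternSet &patterns, PatternBenefit benefit) {
  BackendCallbacks callbacks;
  // Called after each successful LocalStoreOp rewrite with the count and the
  // input type of the LLVM instructions that were emitted for it.
  callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
                                        size_t llvmOpCount, Type llvmOpType) {
    // e.g. accumulate per-op instruction counts for backend-specific tuning.
  };
  populateMemoryOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit,
                                callbacks);
}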

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);
 
-void storeDistributedToShared(
-    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
-    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr);
+void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
+                              Type elemLlvmTy, ArrayRef<Value> srcVals,
+                              Value smemBase, ArrayRef<Value> dstStrides,
+                              Location loc, RewriterBase &rewriter,
+                              const TargetInfoBase &target);
 
 inline Value getStructFromSharedMemoryObject(Location loc,
                                              const SharedMemoryObject &smemObj,
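The llvmOpCount out-parameter deleted here is the other half of the callback plumbing removed from PatternTritonGPUOpToLLVM.h above: the shared-memory store lowering could report how many LLVM ops it emitted and of which type. A rough sketch of how the two removed pieces plausibly fit together at a LocalStoreOp conversion site follows; this is an assumption for illustration, as the actual call site is not part of this diff.

// Assumed call-site shape, not taken from this commit.
std::pair<size_t, Type> llvmOpCount{0, Type()};
storeDistributedToShared(dstTy, srcTy, elemLlvmTy, srcVals, smemBase,
                         dstStrides, loc, rewriter, targetInfo, &llvmOpCount);
if (callbacks && callbacks->localStoreOpConversion)
  callbacks->localStoreOpConversion(op, llvmOpCount.first, llvmOpCount.second);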
