Skip to content

Commit 7f153c3

Browse files
ZelboK and danial javady
authored and committed
[PROTON][AMD] Fix failing proton tests for AMD GPUs (#8763)
Fixes the Proton tests that broke after the upgrade to ROCm 7, and implements CircularStoreOp for global memory (gmem). <!--- The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - [x] I am not making a trivial change, such as fixing a typo in a comment. - [ ] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: danial javady <[email protected]>
1 parent f7a199d commit 7f153c3

File tree

4 files changed

+15
-13
lines changed

4 files changed

+15
-13
lines changed

third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h"
33
#include "Conversion/ProtonGPUToLLVM/Utility.h"
44
#include "Dialect/ProtonGPU/IR/Dialect.h"
5+
#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
56
#include "mlir/Conversion/LLVMCommon/Pattern.h"
67
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
8+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
79
#include "mlir/IR/PatternMatch.h"
810
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
911

@@ -37,7 +39,8 @@ struct CircularStoreOpConversion
3739
// TODO(crobeck): see what buffer ops performance looks like here for
3840
// global mem (address space 1) compared to predicated ops to shared
3941
// memory
40-
llvm::report_fatal_error("unimplemented");
42+
mlir::LLVM::AMD::llStore(rewriter, loc, dataPack.ptr, dataPack.record,
43+
dataPack.isWriter);
4144
} else if (addrSpace == 3) {
4245
targetInfo.getTritonTargetInfo().storeDShared(
4346
rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record,

third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -350,19 +350,22 @@ void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback(
350350
// data on stop
351351
maxCorrelationId =
352352
std::max<uint64_t>(maxCorrelationId, record->correlation_id);
353-
// TODO(Keren): Roctracer doesn't support cuda graph yet.
353+
bool hasCorrelation =
354+
correlation.corrIdToExternId.contain(record->correlation_id);
354355
auto externId =
355-
correlation.corrIdToExternId.contain(record->correlation_id)
356+
hasCorrelation
356357
? correlation.corrIdToExternId.at(record->correlation_id).first
357358
: Scope::DummyScopeId;
358359
auto isAPI = correlation.apiExternIds.contain(externId);
359360
bool isGraph = pImpl->CorrIdToIsHipGraph.contain(record->correlation_id);
360-
processActivity(correlation.corrIdToExternId, correlation.apiExternIds,
361-
externId, dataSet, record, isAPI, isGraph);
362-
// Track correlation ids from the same stream and erase those <
363-
// correlationId
364-
correlation.corrIdToExternId.erase(record->correlation_id);
365-
correlation.apiExternIds.erase(externId);
361+
if (hasCorrelation) {
362+
processActivity(correlation.corrIdToExternId, correlation.apiExternIds,
363+
externId, dataSet, record, isAPI, isGraph);
364+
// Track correlation ids from the same stream and erase those <
365+
// correlationId
366+
} else {
367+
correlation.apiExternIds.erase(externId);
368+
}
366369
roctracer::getNextRecord<true>(record, &record);
367370
}
368371
correlation.complete(maxCorrelationId);

third_party/proton/test/test_instrumentation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
is_cuda,
1616
is_hip,
1717
is_hip_cdna2,
18-
is_hip_cdna4,
1918
supports_tma,
2019
supports_ws,
2120
)
@@ -644,7 +643,6 @@ def foo(x, y, size: tl.constexpr):
644643
assert trace_events[-1]["args"]["call_stack"][-2] == "test"
645644

646645

647-
@pytest.mark.skipif(is_hip_cdna4(), reason="nondeterministic failure")
648646
def test_globaltime(tmp_path: pathlib.Path):
649647
temp_file = tmp_path / "test_globaltime.chrome_trace"
650648
mode = proton.mode.Default(
@@ -760,7 +758,6 @@ def session_kernel_time(session_name: str) -> Tuple[int, int]:
760758
assert session1_loop_time / session0_loop_time < loop_threshold, "Loop kernel overhead too high"
761759

762760

763-
@pytest.mark.skipif(is_hip(), reason="not implemented yet")
764761
def test_gmem_buffer(tmp_path: pathlib.Path):
765762

766763
@triton.jit

third_party/proton/test/test_profile.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def foo(x, y):
8080
assert data[0]["children"][1]["frame"]["name"] == "test2"
8181

8282

83-
@pytest.mark.skipif(is_hip(), reason="Currently broken after updating to ROCm 7")
8483
def test_cudagraph(tmp_path: pathlib.Path, device: str):
8584
if is_xpu():
8685
pytest.skip("xpu doesn't support cudagraph; FIXME: double check")

0 commit comments

Comments
 (0)