Skip to content

Commit b274588

Browse files
authored
[infra] Fix docker workflow (#343)
1 parent 0213e77 commit b274588

File tree

9 files changed

+48
-34
lines changed

9 files changed

+48
-34
lines changed

.github/workflows/docker-rocm.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
uses: docker/login-action@v2
2929
with:
3030
registry: ghcr.io
31-
username: pytorch-labs
31+
username: meta-pytorch
3232
password: ${{ secrets.TRITONBENCH_ACCESS_TOKEN }}
3333
- name: Build TritonBench nightly docker
3434
run: |
@@ -39,19 +39,19 @@ jobs:
3939
# and it is github.ref_name when triggered by workflow_dispatch
4040
branch_name=${{ github.head_ref || github.ref_name }}
4141
docker build . --build-arg TRITONBENCH_BRANCH="${branch_name}" --build-arg FORCE_DATE="${NIGHTLY_DATE}" \
42-
-f tritonbench-rocm-nightly.dockerfile -t ghcr.io/pytorch-labs/tritonbench:rocm-latest
42+
-f tritonbench-rocm-nightly.dockerfile -t ghcr.io/meta-pytorch/tritonbench:rocm-latest
4343
# Extract pytorch version from the docker
44-
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
44+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/meta-pytorch/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
4545
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
46-
docker tag ghcr.io/pytorch-labs/tritonbench:rocm-latest ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
46+
docker tag ghcr.io/meta-pytorch/tritonbench:rocm-latest ghcr.io/meta-pytorch/tritonbench:rocm-${DOCKER_TAG}
4747
- name: Push docker to remote
4848
if: github.event_name != 'pull_request'
4949
run: |
5050
# Extract pytorch version from the docker
51-
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
51+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/meta-pytorch/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
5252
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
53-
docker push ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
54-
docker push ghcr.io/pytorch-labs/tritonbench:rocm-latest
53+
docker push ghcr.io/meta-pytorch/tritonbench:rocm-${DOCKER_TAG}
54+
docker push ghcr.io/meta-pytorch/tritonbench:rocm-latest
5555
concurrency:
5656
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
5757
cancel-in-progress: true

.github/workflows/docker.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ on:
2020
required: false
2121
env:
2222
CONDA_ENV: "tritonbench"
23-
DOCKER_IMAGE: "ghcr.io/pytorch-labs/tritonbench:latest"
23+
DOCKER_IMAGE: "ghcr.io/meta-pytorch/tritonbench:latest"
2424
SETUP_SCRIPT: "/workspace/setup_instance.sh"
2525

2626
jobs:
@@ -38,7 +38,7 @@ jobs:
3838
uses: docker/login-action@v2
3939
with:
4040
registry: ghcr.io
41-
username: pytorch-labs
41+
username: meta-pytorch
4242
password: ${{ secrets.TRITONBENCH_ACCESS_TOKEN }}
4343
- name: Build TritonBench nightly docker
4444
run: |
@@ -49,19 +49,19 @@ jobs:
4949
# and it is github.ref_name when triggered by workflow_dispatch
5050
branch_name=${{ github.head_ref || github.ref_name }}
5151
docker build . --build-arg TRITONBENCH_BRANCH="${branch_name}" --build-arg FORCE_DATE="${NIGHTLY_DATE}" \
52-
-f tritonbench-nightly.dockerfile -t ghcr.io/pytorch-labs/tritonbench:latest
52+
-f tritonbench-nightly.dockerfile -t ghcr.io/meta-pytorch/tritonbench:latest
5353
# Extract pytorch version from the docker
54-
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
54+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/meta-pytorch/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
5555
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
56-
docker tag ghcr.io/pytorch-labs/tritonbench:latest ghcr.io/pytorch-labs/tritonbench:${DOCKER_TAG}
56+
docker tag ghcr.io/meta-pytorch/tritonbench:latest ghcr.io/meta-pytorch/tritonbench:${DOCKER_TAG}
5757
- name: Push docker to remote
5858
if: github.event_name != 'pull_request'
5959
run: |
6060
# Extract pytorch version from the docker
61-
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
61+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/meta-pytorch/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
6262
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
63-
docker push ghcr.io/pytorch-labs/tritonbench:${DOCKER_TAG}
64-
docker push ghcr.io/pytorch-labs/tritonbench:latest
63+
docker push ghcr.io/meta-pytorch/tritonbench:${DOCKER_TAG}
64+
docker push ghcr.io/meta-pytorch/tritonbench:latest
6565
concurrency:
6666
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
6767
cancel-in-progress: true

docker/infra/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ The Infra is a Kubernetes cluster built on top of Google Cloud Platform.
66
## Step 1: Create the cluster and install the ARC Controller
77

88
```
9+
# login ghcr.io so that remote can pull the image
10+
docker login ghcr.io
11+
912
# Get credentials for the cluster so that kubectl could use it
1013
gcloud container clusters get-credentials --location us-central1 tritonbench-h100-cluster
1114

docker/infra/values.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ template:
223223
containers:
224224
- name: runner
225225
# image: ghcr.io/actions/actions-runner:latest
226-
image: ghcr.io/pytorch-labs/tritonbench:latest
226+
image: ghcr.io/meta-pytorch/tritonbench:latest
227227
command: ["sh", "-c", "sudo cp -r /usr/bin/nvidia/* /usr/bin; sudo cp -r /usr/lib/x86_64-linux-gnu/nvidia/* /usr/lib/x86_64-linux-gnu; bash /home/runner/run.sh"]
228228
securityContext:
229229
privileged: true

run.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from tritonbench.operators import load_opbench_by_name
1717
from tritonbench.operators_collection import list_operators_by_collection
18+
from tritonbench.utils.ab_test import compare_ab_results, run_ab_test
1819
from tritonbench.utils.env_utils import is_fbcode
1920
from tritonbench.utils.gpu_utils import gpu_lockdown
2021
from tritonbench.utils.list_operator_details import list_operator_details
@@ -23,7 +24,6 @@
2324

2425
from tritonbench.utils.triton_op import BenchmarkOperatorResult
2526
from tritonbench.utils.tritonparse_utils import tritonparse_init, tritonparse_parse
26-
from tritonbench.utils.ab_test import run_ab_test, compare_ab_results
2727

2828
try:
2929
if is_fbcode():
@@ -34,8 +34,6 @@
3434
usage_report_logger = lambda *args, **kwargs: None
3535

3636

37-
38-
3937
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
4038
if is_loader_op(args.op):
4139
Opbench = get_op_loader_bench_cls_by_name(args.op)
@@ -132,23 +130,26 @@ def run(args: List[str] = []):
132130
# Check if A/B testing mode is enabled
133131
if args.side_a is not None and args.side_b is not None:
134132
# A/B testing mode - only support single operator
135-
assert len(ops) == 1, "A/B testing validation should have caught multiple operators"
133+
assert (
134+
len(ops) == 1
135+
), "A/B testing validation should have caught multiple operators"
136136
op = ops[0]
137137
args.op = op
138-
138+
139139
print("[A/B Testing Mode Enabled]")
140140
print(f"Operator: {op}")
141141
print()
142-
142+
143143
with gpu_lockdown(args.gpu_lockdown):
144144
try:
145145
result_a, result_b = run_ab_test(args, extra_args, _run)
146-
146+
147147
from tritonbench.utils.ab_test import parse_ab_config
148+
148149
config_a_args = parse_ab_config(args.side_a)
149150
config_b_args = parse_ab_config(args.side_b)
150151
compare_ab_results(result_a, result_b, config_a_args, config_b_args)
151-
152+
152153
except Exception as e:
153154
print(f"A/B test failed: {e}")
154155
if not args.bypass_fail:
@@ -166,7 +167,7 @@ def run(args: List[str] = []):
166167
run_in_task(op)
167168
else:
168169
_run(args, extra_args)
169-
170+
170171
tritonparse_parse(args.tritonparse)
171172

172173

tools/python_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
DEFAULT_PYTHON_VERSION = "3.12"
99

1010
PYTHON_VERSION_MAP = {
11-
"3.11": {
11+
"3.11": {
1212
"pytorch_url": "cp311",
13-
},
14-
"3.12": {
13+
},
14+
"3.12": {
1515
"pytorch_url": "cp312",
16-
},
16+
},
1717
}
1818
REPO_DIR = Path(__file__).parent.parent
1919

tritonbench/operators/addmm/operator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
except ModuleNotFoundError:
1212
from .hstu import triton_addmm
1313

14-
from tritonbench.operators.gemm.stream_k import streamk_matmul
14+
try:
15+
from tritonbench.operators.gemm.stream_k import streamk_matmul
16+
except ImportError:
17+
streamk_matmul = None
18+
19+
from tritonbench.operators.gemm import stream_k
1520
from tritonbench.utils.triton_op import (
1621
BenchmarkOperator,
1722
BenchmarkOperatorMetrics,
@@ -94,7 +99,7 @@ def __init__(
9499
def triton_addmm(self, a, mat1, mat2) -> Callable:
95100
return lambda: triton_addmm(a, mat1, mat2)
96101

97-
@register_benchmark()
102+
@register_benchmark(enabled=bool(streamk_matmul))
98103
def streamk_addmm(self, a, mat1, mat2) -> Callable:
99104
return lambda: streamk_matmul(mat1, mat2, bias=a)
100105

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,17 @@ def args(m, n, k):
6565
)
6666
return (a, b)
6767

68-
if hasattr(self, 'external_shapes') and self.external_shapes: # Check for external shapes loaded from input-loader
68+
if (
69+
hasattr(self, "external_shapes") and self.external_shapes
70+
): # Check for external shapes loaded from input-loader
6971
for shape in self.external_shapes:
7072
if len(shape) == 3:
7173
m, n, k = shape
7274
yield args(m, n, k)
7375
else:
74-
logger.warning(f"Skipping invalid shape: {shape}, expected [M, N, K]")
76+
logger.warning(
77+
f"Skipping invalid shape: {shape}, expected [M, N, K]"
78+
)
7579
elif self.extra_args.llama:
7680
for m, n, k, _bias in llama_shapes():
7781
yield args(m, n, k)

tritonbench/utils/triton_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# utils to identify triton versions
22

3-
import triton.language as tl
43
import functools
54
import importlib.util
65

6+
import triton.language as tl
7+
78

89
class AsyncTaskContext:
910
"""Context manager that dispatches to tl.async_task if available, otherwise no-op."""

0 commit comments

Comments (0)