Skip to content

Commit cf61339

Browse files
authored
Merge branch 'main' into grouped_tensor_python
2 parents 40c619e + 6a34b65 commit cf61339

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ jobs:
9292
options: --user root
9393
steps:
9494
- name: 'Dependencies'
95-
run: pip install pybind11[global]
95+
run: pip install cmake==3.21.0 pybind11[global]
9696
- name: 'Checkout'
9797
uses: actions/checkout@v3
9898
with:
@@ -144,7 +144,7 @@ jobs:
144144
- name: 'Dependencies'
145145
run: |
146146
docker exec builder bash -c '\
147-
pip install pybind11[global] einops onnxscript && \
147+
pip install cmake==3.21.0 pybind11[global] einops onnxscript && \
148148
pip install torch --no-cache-dir --index-url https://download.pytorch.org/whl/cu130
149149
'
150150
- name: 'Build'

examples/jax/encoder/test_model_parallel_encoder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# See LICENSE for license information.
44
"""Encoder training on multi-GPU with tensor parallelism"""
55
import argparse
6+
import os
67
import unittest
78
from functools import partial
89

@@ -489,6 +490,9 @@ class TestEncoder(unittest.TestCase):
489490

490491
def setUp(self):
491492
"""Run 5 epochs for testing"""
493+
# TODO(jberchtold): Remove once fused attention from cuDNN supports determinism on Blackwell
494+
if "NVTE_FUSED_ATTN" not in os.environ:
495+
os.environ["NVTE_FUSED_ATTN"] = "0"
492496
self.args = encoder_parser(["--epochs", "5"])
493497

494498
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")

transformer_engine/common/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,24 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
232232
target_include_directories(transformer_engine PUBLIC
233233
"${CMAKE_CURRENT_SOURCE_DIR}/include")
234234

235-
# CUTLASS kernels require SM90a and cause hang in debug build
235+
# Grouped GEMM kernels require SM90a
236236
set_property(
237237
SOURCE gemm/cutlass_grouped_gemm.cu
238238
APPEND
239239
PROPERTY
240-
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
240+
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a")
241+
242+
# CUTLASS kernels can cause hangs in debug builds
243+
set(CUTLASS_KERNEL_SOURCES
244+
gemm/cutlass_grouped_gemm.cu
245+
hadamard_transform/group_hadamard_transform_cast_fusion.cu
246+
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
247+
hadamard_transform/hadamard_transform_cast_fusion.cu)
248+
set_property(
249+
SOURCE ${CUTLASS_KERNEL_SOURCES}
250+
APPEND
251+
PROPERTY
252+
COMPILE_OPTIONS "-g0;-dopt=on")
241253

242254
# Configure dependencies
243255
target_link_libraries(transformer_engine PUBLIC

0 commit comments

Comments
 (0)