File tree Expand file tree Collapse file tree 3 files changed +20
-4
lines changed
transformer_engine/common Expand file tree Collapse file tree 3 files changed +20
-4
lines changed Original file line number Diff line number Diff line change 9292 options : --user root
9393 steps :
9494 - name : ' Dependencies'
95- run : pip install pybind11[global]
95+ run : pip install cmake==3.21.0 pybind11[global]
9696 - name : ' Checkout'
9797 uses : actions/checkout@v3
9898 with :
@@ -144,7 +144,7 @@ jobs:
144144 - name : ' Dependencies'
145145 run : |
146146 docker exec builder bash -c '\
147- pip install pybind11[global] einops onnxscript && \
147+ pip install cmake==3.21.0 pybind11[global] einops onnxscript && \
148148 pip install torch --no-cache-dir --index-url https://download.pytorch.org/whl/cu130
149149 '
150150 - name : ' Build'
Original file line number Diff line number Diff line change 33# See LICENSE for license information.
44"""Encoder training on multi-GPU with tesnor parallelism"""
55import argparse
6+ import os
67import unittest
78from functools import partial
89
@@ -489,6 +490,9 @@ class TestEncoder(unittest.TestCase):
489490
490491 def setUp (self ):
491492 """Run 5 epochs for testing"""
493+ # TODO(jberchtold): Remove once fused attention from cuDNN supports determinism on Blackwell
494+ if "NVTE_FUSED_ATTN" not in os .environ :
495+ os .environ ["NVTE_FUSED_ATTN" ] = "0"
492496 self .args = encoder_parser (["--epochs" , "5" ])
493497
494498 @unittest .skipIf (not is_bf16_supported (), "Device compute capability 8.0+ is required for BF16" )
Original file line number Diff line number Diff line change @@ -232,12 +232,24 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
232232target_include_directories (transformer_engine PUBLIC
233233 "${CMAKE_CURRENT_SOURCE_DIR} /include" )
234234
235- # CUTLASS kernels require SM90a and cause hang in debug build
235+ # Grouped GEMM kernels require SM90a
236236set_property (
237237 SOURCE gemm/cutlass_grouped_gemm.cu
238238 APPEND
239239 PROPERTY
240- COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0" )
240+ COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a" )
241+
242+ # CUTLASS kernels could cause hang in debug build
243+ set (CUTLASS_KERNEL_SOURCES
244+ gemm/cutlass_grouped_gemm.cu
245+ hadamard_transform/group_hadamard_transform_cast_fusion.cu
246+ hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
247+ hadamard_transform/hadamard_transform_cast_fusion.cu)
248+ set_property (
249+ SOURCE ${CUTLASS_KERNEL_SOURCES}
250+ APPEND
251+ PROPERTY
252+ COMPILE_OPTIONS "-g0;-dopt=on" )
241253
242254# Configure dependencies
243255target_link_libraries (transformer_engine PUBLIC
You can’t perform that action at this time.
0 commit comments