IFU dev v2.6 #374
base: dev
Conversation
Signed-off-by: Przemek Tredak <[email protected]>
* tests drop Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move dir Signed-off-by: Pawel Gadzinski <[email protected]> * tests fox Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemek Tredak <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix README render on PyPI Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update README.rst Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Use anonymous hyperlink for duplicate. Fix indent. Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Check tensor-recipe compatibility Signed-off-by: Evgeny Tsykunov <[email protected]> * Tensor class in recipe, checking for *Base Signed-off-by: Evgeny Tsykunov <[email protected]> * Extend recipe __repr__ with recipe_type Signed-off-by: Evgeny Tsykunov <[email protected]> * Warn about recipe change Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enable dynamic recipe change: clear fp8 workspace Signed-off-by: Evgeny Tsykunov <[email protected]> * TE 1.x checkpoint compatibility Signed-off-by: Evgeny Tsykunov <[email protected]> * Disable warning for recipe wrappers Signed-off-by: Evgeny Tsykunov <[email protected]> * Test recipe change Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use QuantizedTensorBase Signed-off-by: Evgeny Tsykunov <[email protected]> * Fix circular import Signed-off-by: Evgeny Tsykunov <[email protected]> * Revert previous circular import fix Signed-off-by: Evgeny Tsykunov <[email protected]> * Fix pytorch imports in common Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Let quantizer know about the recipe Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix imports Signed-off-by: Evgeny Tsykunov <[email protected]> --------- Signed-off-by: Evgeny Tsykunov <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix split_overlap_rs aggregate=True chunk offset calculation Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add unit test for aggregate=True Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix unit test Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Guyue Huang <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…te (#1799) * Use an empty torch tensor to indicate no fp8 information in extra_state Signed-off-by: Peter St. John <[email protected]> * Add huggingface from_pretrained / save_pretrained tests Adds integration tests to ensure models containing TransformerLayer objects can be saved and loaded using the from_pretrained and save_pretrained methods. Signed-off-by: Peter St. John <[email protected]> --------- Signed-off-by: Peter St. John <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…n (#1611) * docs drop Signed-off-by: Pawel Gadzinski <[email protected]> * a Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * Update docs/debug/1_getting_started.rst Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> * Update docs/debug/1_getting_started.rst Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> * fixes Signed-off-by: Pawel Gadzinski <[email protected]> * fix imgs Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]>
add docstring for CP Signed-off-by: Charlene Yang <[email protected]>
* Add missing docs for C API Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Grammar, typos, copy-paste errors Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * remove contiguous word Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Better wording Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* fix model parallel encoder to be properly sharded Signed-off-by: Sudhakar Singh <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sudhakar Singh <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
fix saved_tensors Signed-off-by: Pawel Gadzinski <[email protected]>
Fix incorrectly skipped test_quantize_dbias tests Signed-off-by: Jeremy Berchtold <[email protected]>
Remove comm_gemm_overlap docs Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Build support for cuda 13 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix build for cudnn 8.9*; cuda 12.1 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * readd include Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…rity (#1811) Make primitive names more granular for better disabling granularity Signed-off-by: Jeremy Berchtold <[email protected]>
Document all recipes Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…#1804) Activation ops support fusing backward pass with quantize Signed-off-by: Tim Moon <[email protected]>
* Fix env variable name in test.sh scripts to properly test pure-JAX implementations Signed-off-by: Jeremy Berchtold <[email protected]> * Update test scripts to use pure-JAX impl in encoder test_custom_call_compute.py already uses pure-JAX impl as reference so testing the pure-JAX impl against itself would be redundant. The encoder tests have their own implementation so testing the pure-JAX impl of primitives is still useful. Signed-off-by: Jeremy Berchtold <[email protected]> * Update qa/L0_jax_unittest/test.sh Co-authored-by: Phuong Nguyen <[email protected]> Signed-off-by: jberchtold-nvidia <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]> Signed-off-by: jberchtold-nvidia <[email protected]> Co-authored-by: Phuong Nguyen <[email protected]>
* Modify the test cases Signed-off-by: Przemek Tredak <[email protected]> * Make the tests reproducible on different machines Signed-off-by: Przemek Tredak <[email protected]> * Fixed the cache of the gamma_in_weight_dtype setting Signed-off-by: Przemek Tredak <[email protected]> * Reinstate the tests Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More verbose code and comments Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added conda installation Signed-off-by: Santosh Bhavani <[email protected]> * fix for pypi Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Santosh Bhavani <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix single FW build with multi FW available Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Some fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * sug Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…ion (#1822) Update jax_scaled_masked_softmax to match TE kernel implementation Signed-off-by: Jeremy Berchtold <[email protected]>
* fp8 gemm with direct quant Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
* removes unnecessary reshapes for FP8 GEMM * use nn.jax.scaled_matmul Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
…needed (#1817) * Linear op avoids saving input tensor if weight grad is not needed Signed-off-by: Tim Moon <[email protected]> * Linear op forward avoids producing quantized tensors with unnecessary usages Signed-off-by: Tim Moon <[email protected]> * Fix linter warnings Signed-off-by: Tim Moon <[email protected]> * Avoid unnecessary usages in fused linear ops Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]>
…#1813) * Changed the Tensor allocation strategy Signed-off-by: Przemek Tredak <[email protected]> * Fixes Signed-off-by: Przemek Tredak <[email protected]> * Disable debug flag Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the double free error Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fixed pyTorch recipe extension Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fix Signed-off-by: Przemek Tredak <[email protected]> * Hide TensorAllocator and fix the usage in LayerNorm Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaning Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fix permutation Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Support SWA in CP Ring Attn THD striped sharding Signed-off-by: Hua Huang <[email protected]> * Add some comments; move check to _FusedAttnCPWithP2PHelper.check_supported() Signed-off-by: Hua Huang <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Remove unused check Signed-off-by: Hua Huang <[email protected]> --------- Signed-off-by: Hua Huang <[email protected]>
Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Quantizer update Signed-off-by: Evgeny Tsykunov <[email protected]> * Update import Signed-off-by: Evgeny <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Introduce _update_weight_quantizers and _get_weight_tensors/_get_weight_quantizers Signed-off-by: Evgeny <[email protected]> * Add test Signed-off-by: Evgeny <[email protected]> * Move _quantizer to the QuantizedTensorBase Signed-off-by: Evgeny <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix import Signed-off-by: Evgeny Tsykunov <[email protected]> --------- Signed-off-by: Evgeny Tsykunov <[email protected]> Signed-off-by: Evgeny <[email protected]> Co-authored-by: Evgeny Tsykunov <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak <[email protected]>
…talled (#1834) * Add warning for multi framework case Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Alp Dener <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Alp Dener <[email protected]>
* remove unnecessary padding Signed-off-by: Phuong Nguyen <[email protected]> * adapt the test_distributed_layernorm byte count Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
* optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid reindexing from python side Signed-off-by: Charlene Yang <[email protected]> * rename variable from previous commit Signed-off-by: Charlene Yang <[email protected]> * minor fix Signed-off-by: Charlene Yang <[email protected]> * minor fix Signed-off-by: Charlene Yang <[email protected]> --------- Signed-off-by: Charlene Yang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* update cudnn-frontend to 1.13.0 Signed-off-by: Charlene Yang <[email protected]> * disable 9.11 for a bug Signed-off-by: Charlene Yang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix selection logic Signed-off-by: Charlene Yang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…it test (#1967) * set precision=HIGHEST for the ref_grouped_gemm impl in the unit test Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
* enable cudnn norm tests Signed-off-by: Phuong Nguyen <[email protected]> * exclude tests on pre-Hopper Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
Update tolerance of distributed layernorm MLP for FP8 Signed-off-by: Jeremy Berchtold <[email protected]>
#ifdef __HIP_PLATFORM_AMD__
// TODO: remove after rocm supports NV __syncwarp equivalent
__device__ inline void __syncwarp()
I remember there were some discussions of this before. Could we use cooperative groups here, with some rewriting, to avoid the performance hit of the fences?
I think __syncwarp is a warp-level operation that is independent of cooperative groups (a group of blocks, which in turn contain warps). After researching around, it looks like this approach is currently our best option on ROCm.
Right, I think you can use cooperative groups like this:
auto block = cooperative_groups::this_thread_block();
auto wave = cooperative_groups::tiled_partition<64>(block);
then sync just the wave:
wave.sync();
maybe that is more overhead, but if the calling function syncs multiple times it might be faster than the fence.
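For reference, a minimal sketch of that suggestion as a ROCm-side __syncwarp() replacement. It assumes the HIP cooperative-groups header and a 64-lane wavefront, and is not taken from this PR's diff:
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_cooperative_groups.h>
// Sketch only: partition the thread block into wavefront-sized tiles and
// synchronize just the calling wave, instead of issuing a block-wide fence.
__device__ inline void __syncwarp() {
  auto block = cooperative_groups::this_thread_block();
  auto wave = cooperative_groups::tiled_partition<64>(block);
  wave.sync();
}
#endif  // __HIP_PLATFORM_AMD__
Whether this beats the fence likely depends on how often the calling function synchronizes, as noted above.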
#else
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
              std::is_same_v<InputType, fp4e2m1>){
#endif
If the fp4e2m1 datatype is defined on ROCm, the HIP_PLATFORM_AMD guard is not needed.
We do have an fp4_e2m1 defined in ROCm, but I'm not sure whether it corresponds exactly to the NV fp4 type:
/opt/rocm/include/hip/amd_detail/amd_hip_fp4.h:264:struct __hip_fp4_e2m1 {
That's why I didn't enable it in transformer_engine/common/common.h for now
I mean that if it builds and we disable FP4 support elsewhere, this method should never actually be called with FP4, so the code may just work correctly without extra guards.
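As a rough illustration of that point, here is a hypothetical sketch (the function name and placeholder type aliases are mine; TE's real storage types live in its common headers). Once fp4e2m1 is defined on ROCm, the same constexpr branch compiles on both platforms without a __HIP_PLATFORM_AMD__ guard, and with FP4 disabled at a higher level the FP4 branch is simply never exercised:
#include <type_traits>
// Placeholder aliases so the sketch is self-contained; TE defines the real
// fp8/fp4 storage types in transformer_engine/common/common.h.
struct fp8e4m3 {};
struct fp8e5m2 {};
struct fp4e2m1 {};
template <typename InputType>
void load_elements(const InputType *input) {
  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
                std::is_same_v<InputType, fp4e2m1>) {
    // narrow-precision path shared by CUDA and ROCm builds
  } else {
    // regular-precision path
  }
}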
I tried to include <hip/hip_fp4.h>. It compiled fine for libtransformer_engine.so but failed when compiling the PyTorch extension .so:
c++ -MMD -MF /tmp/tmplbl2x2ld.build-temp/transformer_engine/pytorch/csrc/extensions/bias_hip.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/opt/rocm/include -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common/include -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc -I/opt/venv/lib/python3.10/site-packages/torch/include -I/opt/venv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/opt/venv/include -I/usr/include/python3.10 -c -c /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc/extensions/bias_hip.cpp -o /tmp/tmplbl2x2ld.build-temp/transformer_engine/pytorch/csrc/extensions/bias_hip.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -O3 -fvisibility=hidden -g0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=transformer_engine_torch -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++17
In file included from /opt/rocm/include/hip/amd_detail/amd_hip_fp4.h:31,
from /opt/rocm/include/hip/hip_fp4.h:29,
from /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common/common_hip.h:28,
from /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc/extensions/bias_hip.cpp:8:
/opt/rocm/include/hip/amd_detail/amd_hip_ocp_types.h:87:2: error: #error "Only supported by HIPCC or GCC >= 13."
87 | #error "Only supported by HIPCC or GCC >= 13."
| ^~~~~
Our CI image has c++ version 11.4:
c++ (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
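One possible workaround (a hypothetical sketch, not something this PR does): gate the include on the compilers that amd_hip_ocp_types.h accepts, so the hipcc-built core library gets the header while the g++ 11 extension build skips it. TE_COMMON_HAS_HIP_FP4 is an illustrative macro name:
// Only pull in the ROCm FP4 header when the compiler can handle it
// ("Only supported by HIPCC or GCC >= 13").
#if defined(__HIP_PLATFORM_AMD__) && \
    (defined(__HIPCC__) || (defined(__GNUC__) && __GNUC__ >= 13))
#include <hip/hip_fp4.h>
#define TE_COMMON_HAS_HIP_FP4 1
#endif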
@@ -476,6 +476,8 @@ void rocm_norm_mxfp8_quantize(LaunchParams<ForwardKernelParams> &launch_params)
}
#endif

bool& use_zero_centered_gamma_in_weight_dtype();
It should probably be guarded with #ifndef HIP_PLATFORM_AMD. Maybe move it to line 430..433, before rocm_norm_mxfp8_quantize(), which is a ROCm addition.
Done. Thanks
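So the declaration ends up roughly as below (a sketch of the suggested placement, not the exact diff):
#ifndef __HIP_PLATFORM_AMD__
// CUDA-only helper, excluded from ROCm builds and kept just before the
// ROCm-specific rocm_norm_mxfp8_quantize() addition.
bool& use_zero_centered_gamma_in_weight_dtype();
#endif  // !__HIP_PLATFORM_AMD__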
@@ -222,6 +225,16 @@ const std::string &include_directory(bool required) {
}
#endif  // __HIP_PLATFORM_AMD__

int cudart_version() {
  auto get_version = []() -> int {
What will it return on ROCm?
It won't run on ROCm; I just moved it inside the HIP_PLATFORM_AMD guard.
@@ -1,4 +1,6 @@
/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
Is it different from upstream?
I tried to make it the same as upstream but ran into several compilation failures resulting from hipify rewriting "cuda_runtime.h" --> "hip/hip_runtime.h". In fact, this cuda_runtime.h is under transformer_engine/common/util; I guess it's bad naming.
I added a manual guard instead.
// break;
#ifndef USE_ROCM
case xla::ffi::DataType::F8E8M0FNU:
  return DType::kFloat8E8M0;
TE supports this datatype. Why is it guarded?
Removed. Thanks
@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
    """Return True if using Blackwell architecture, False otherwise."""
    compute_capability = get_device_compute_capability()
Need ROCm path here
Added guard. Thanks
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]
Since .toml does not support conditions, we need to replace jax[cuda12] with just jax, or make it ROCm-specific and request jax[rocm].
Done. Thanks
In fact, will this requires entry take effect at all, given that we are using "--no-build-isolation"?
Description
Targeted NV upstream commit: ca7407e (2025/07/18), based on our ROCm dev commit 6bbd03c.
Fixes https://github.com/ROCm/frameworks-internal/issues/13729
Type of change
Changes
See the NV upstream release doc for upstream changes.
Our IFU conflict resolutions are listed in the following commits:
1). common: 4d3ca4d
2). jax extension: 5ce0afd
3). pytorch extension: c9c9126
4). build/installation: 9730903
5). cpp gtests: 51bdbb8
6). pytorch pytests: ba59f81
7). jax pytests: 5842c24
Checklist: