Conversation

@wangye805 (Collaborator) commented Nov 19, 2025

Description

Targeted NV upstream commit: ca7407e (2025/07/18), based on our ROCm dev commit 6bbd03c.
Fixes https://github.com/ROCm/frameworks-internal/issues/13729

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See the NV upstream release notes for upstream changes.
Our IFU conflict resolutions are listed in the following commits:
1). common: 4d3ca4d
2). jax extension: 5ce0afd
3). pytorch extension: c9c9126
4). build/installation: 9730903
5). cpp gtests: 51bdbb8
6). pytorch pytests: ba59f81
7). jax pytests: 5842c24

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ptrendx and others added 30 commits May 16, 2025 17:17
Signed-off-by: Przemek Tredak <[email protected]>
* tests drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move dir

Signed-off-by: Pawel Gadzinski <[email protected]>

* tests fox

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Przemek Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemek Tredak <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix README render on PyPI

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Update README.rst

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Use anonymous hyperlink for duplicate. Fix indent.

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Check tensor-recipe compatibility

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Tensor class in recipe, checking for *Base

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Extend recipe __repr__ with recipe_type

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Warn about recipe change

Signed-off-by: Evgeny Tsykunov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable dynamic recipe change: clear fp8 workspace

Signed-off-by: Evgeny Tsykunov <[email protected]>

* TE 1.x checkpoint compatibility

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Disable warning for recipe wrappers

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Test recipe change

Signed-off-by: Evgeny Tsykunov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use QuantizedTensorBase

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Fix circular import

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Revert previous circular import fix

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Fix pytorch imports in common

Signed-off-by: Evgeny Tsykunov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Let quantizer know about the recipe

Signed-off-by: Evgeny Tsykunov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports

Signed-off-by: Evgeny Tsykunov <[email protected]>

---------

Signed-off-by: Evgeny Tsykunov <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix split_overlap_rs aggregate=True chunk offset calculation

Signed-off-by: Guyue Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit test for aggregate=True

Signed-off-by: Guyue Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix unit test

Signed-off-by: Guyue Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Guyue Huang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…te (#1799)

* Use an empty torch tensor to indicate no fp8 information in extra_state

Signed-off-by: Peter St. John <[email protected]>

* Add huggingface from_pretrained / save_pretrained tests

Adds integration tests to ensure models containing TransformerLayer
objects can be saved and loaded using the from_pretrained and
save_pretrained methods.

Signed-off-by: Peter St. John <[email protected]>

---------

Signed-off-by: Peter St. John <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…n (#1611)

* docs drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* a

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* Update docs/debug/1_getting_started.rst

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Paweł Gadziński <[email protected]>

* Update docs/debug/1_getting_started.rst

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Paweł Gadziński <[email protected]>

* fixes

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix imgs

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Paweł Gadziński <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
add docstring for CP

Signed-off-by: Charlene Yang <[email protected]>
* Add missing docs for C API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Grammar, typos, copy-paste errors

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* remove contiguous word

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Better wording

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* fix model parallel encoder to be properly sharded

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
fix saved_tensors

Signed-off-by: Pawel Gadzinski <[email protected]>
Fix incorrectly skipped test_quantize_dbias tests

Signed-off-by: Jeremy Berchtold <[email protected]>
Remove comm_gemm_overlap docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Build support for cuda 13

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix build for cudnn 8.9*; cuda 12.1

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* readd include

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…rity (#1811)

Make primitive names more granular for better disabling granularity

Signed-off-by: Jeremy Berchtold <[email protected]>
Document all recipes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…#1804)

Activation ops support fusing backward pass with quantize

Signed-off-by: Tim Moon <[email protected]>
* Fix env variable name in test.sh scripts to properly test pure-JAX implementations

Signed-off-by: Jeremy Berchtold <[email protected]>

* Update test scripts to use pure-JAX impl in encoder

test_custom_call_compute.py already uses pure-JAX impl as
reference so testing the pure-JAX impl against itself would be
redundant. The encoder tests have their own implementation so
testing the pure-JAX impl of primitives is still useful.

Signed-off-by: Jeremy Berchtold <[email protected]>

* Update qa/L0_jax_unittest/test.sh

Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
* Modify the test cases

Signed-off-by: Przemek Tredak <[email protected]>

* Make the tests reproducible on different machines

Signed-off-by: Przemek Tredak <[email protected]>

* Fixed the cache of the gamma_in_weight_dtype setting

Signed-off-by: Przemek Tredak <[email protected]>

* Reinstate the tests

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More verbose code and comments

Signed-off-by: Przemek Tredak <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added conda installation

Signed-off-by: Santosh Bhavani <[email protected]>

* fix for pypi

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Santosh Bhavani <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix single FW build with multi FW available

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Some fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* sug

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…ion (#1822)

Update jax_scaled_masked_softmax to match TE kernel implementation

Signed-off-by: Jeremy Berchtold <[email protected]>
* fp8 gemm with direct quant

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
* removes unnecessary reshapes for FP8 GEMM

* use nn.jax.scaled_matmul

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
…needed (#1817)

* Linear op avoids saving input tensor if weight grad is not needed

Signed-off-by: Tim Moon <[email protected]>

* Linear op forward avoids producing quantized tensors with unnecessary usages

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Avoid unnecessary usages in fused linear ops

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
…#1813)

* Changed the Tensor allocation strategy

Signed-off-by: Przemek Tredak <[email protected]>

* Fixes

Signed-off-by: Przemek Tredak <[email protected]>

* Disable debug flag

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the double free error

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Przemek Tredak <[email protected]>

* Fixed pyTorch recipe extension

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Przemek Tredak <[email protected]>

* Fix

Signed-off-by: Przemek Tredak <[email protected]>

* Hide TensorAllocator and fix the usage in LayerNorm

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleaning

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Przemek Tredak <[email protected]>

* Fix permutation

Signed-off-by: Przemek Tredak <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Support SWA in CP Ring Attn THD striped sharding

Signed-off-by: Hua Huang <[email protected]>

* Add some comments; move check to _FusedAttnCPWithP2PHelper.check_supported()

Signed-off-by: Hua Huang <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Remove unused check

Signed-off-by: Hua Huang <[email protected]>

---------

Signed-off-by: Hua Huang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Quantizer update

Signed-off-by: Evgeny Tsykunov <[email protected]>

* Update import

Signed-off-by: Evgeny <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Introduce _update_weight_quantizers and _get_weight_tensors/_get_weight_quantizers

Signed-off-by: Evgeny <[email protected]>

* Add test

Signed-off-by: Evgeny <[email protected]>

* Move _quantizer to the QuantizedTensorBase

Signed-off-by: Evgeny <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix import

Signed-off-by: Evgeny Tsykunov <[email protected]>

---------

Signed-off-by: Evgeny Tsykunov <[email protected]>
Signed-off-by: Evgeny <[email protected]>
Co-authored-by: Evgeny Tsykunov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
…talled (#1834)

* Add warning for multi framework case

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Alp Dener <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
cyanguwa and others added 12 commits July 17, 2025 20:28
* update cudnn-frontend to 1.13.0

Signed-off-by: Charlene Yang <[email protected]>

* disable 9.11 for a bug

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix selection logic

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…it test (#1967)

* set precision=HIGHEST for the ref_grouped_gemm impl in the unit test

Signed-off-by: Phuong Nguyen <[email protected]>


---------

Signed-off-by: Phuong Nguyen <[email protected]>
* enable cudnn norm tests

Signed-off-by: Phuong Nguyen <[email protected]>

* exclude tests on pre-Hopper

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Update tolerance of distributed layernorm MLP for FP8

Signed-off-by: Jeremy Berchtold <[email protected]>
#else
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
#endif
Collaborator:

If the fp4e2m1 datatype is defined on ROCm, the HIP_PLATFORM_AMD guard is not needed.

Collaborator Author:

We do have an fp4_e2m1 defined in ROCm, but I'm not sure whether it is the same type as the corresponding NV fp4:

/opt/rocm/include/hip/amd_detail/amd_hip_fp4.h:264:struct __hip_fp4_e2m1 {

That's why I didn't enable it in transformer_engine/common/common.h for now.

Collaborator:

I mean that if it builds, and we disable FP4 support elsewhere, this method should never actually be called with FP4, so we may just have code working correctly without extra guards.

Collaborator Author:

I tried including <hip/hip_fp4.h>. It compiled fine for libtransformer_engine.so but failed when compiling the PyTorch extension .so:

  c++ -MMD -MF /tmp/tmplbl2x2ld.build-temp/transformer_engine/pytorch/csrc/extensions/bias_hip.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/opt/rocm/include -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common/include -I/workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc -I/opt/venv/lib/python3.10/site-packages/torch/include -I/opt/venv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/opt/venv/include -I/usr/include/python3.10 -c -c /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc/extensions/bias_hip.cpp -o /tmp/tmplbl2x2ld.build-temp/transformer_engine/pytorch/csrc/extensions/bias_hip.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -O3 -fvisibility=hidden -g0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=transformer_engine_torch -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++17
  In file included from /opt/rocm/include/hip/amd_detail/amd_hip_fp4.h:31,
                   from /opt/rocm/include/hip/hip_fp4.h:29,
                   from /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/common/common_hip.h:28,
                   from /workspace/te_ifu_v2.6_ca7407e-6bbd03c/transformer_engine/pytorch/csrc/extensions/bias_hip.cpp:8:
  /opt/rocm/include/hip/amd_detail/amd_hip_ocp_types.h:87:2: error: #error "Only supported by HIPCC or GCC >= 13."
     87 | #error "Only supported by HIPCC or GCC >= 13."
        |  ^~~~~

Our CI image has c++ version 11.4:

c++ (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
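
For reference, one way to expose the type without extra call-site guards while avoiding the GCC < 13 failure above; the alias, the fallback struct, and the layout-compatibility claim are all assumptions for illustration, not what this PR ships:

// Hypothetical sketch only, not part of this PR: define fp4e2m1 on ROCm when the host
// compiler can parse amd_hip_ocp_types.h (HIPCC or GCC >= 13), mirroring the CUDA alias.
#ifdef __HIP_PLATFORM_AMD__
  #if defined(__HIPCC__) || (defined(__GNUC__) && __GNUC__ >= 13)
    #include <hip/hip_fp4.h>                         // provides __hip_fp4_e2m1
    namespace transformer_engine { using fp4e2m1 = __hip_fp4_e2m1; }
  #else
    // Host compiler too old for amd_hip_ocp_types.h: keep a 1-byte stand-in so
    // std::is_same_v<InputType, fp4e2m1> checks still compile; FP4 paths stay disabled.
    namespace transformer_engine { struct fp4e2m1 { unsigned char bits; }; }
  #endif
#else
  #include <cuda_fp4.h>                              // provides __nv_fp4_e2m1 (CUDA 12.8+)
  namespace transformer_engine { using fp4e2m1 = __nv_fp4_e2m1; }
#endif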

@@ -476,6 +476,8 @@ void rocm_norm_mxfp8_quantize(LaunchParams<ForwardKernelParams> &launch_params)
}
#endif

bool& use_zero_centered_gamma_in_weight_dtype();
Collaborator:

It should probably be guarded with #ifndef HIP_PLATFORM_AMD. Maybe move it to lines 430..433, before rocm_norm_mxfp8_quantize(), which is a ROCm addition.

Collaborator Author:

Done. Thanks
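
For reference, the guarded declaration discussed above would look roughly like this (sketch only; exact placement relative to rocm_norm_mxfp8_quantize() as suggested):

// Sketch: presumably a CUDA-only toggle, so the declaration is compiled out on ROCm builds.
#ifndef __HIP_PLATFORM_AMD__
bool& use_zero_centered_gamma_in_weight_dtype();
#endif  // !__HIP_PLATFORM_AMD__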

@@ -222,6 +225,16 @@ const std::string &include_directory(bool required) {
}
#endif // __HIP_PLATFORM_AMD__

int cudart_version() {
auto get_version = []() -> int {
Collaborator:

What will it return on ROCm?

Collaborator Author:

It won't run on ROCm; I just moved it inside the HIP_PLATFORM_AMD guard.
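
A minimal sketch of that arrangement; the body shown is the usual CUDA runtime query and may differ from the actual upstream implementation:

// Sketch: cudart_version() lives inside the existing CUDA-only region, so ROCm builds
// never compile or call it.
#ifndef __HIP_PLATFORM_AMD__
int cudart_version() {
  auto get_version = []() -> int {
    int version = 0;
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));  // assumes TE's usual CUDA error-check macro
    return version;
  };
  static int version = get_version();
  return version;
}
#endif  // !__HIP_PLATFORM_AMD__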

@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
Collaborator:

Is it different from upstream?

Collaborator Author:

I tried to make it the same as upstream but ran into several compilation failures resulting from hipify converting "cuda_runtime.h" --> "hip/hip_runtime.h". In fact, this cuda_runtime.h is under transformer_engine/common/util; I guess it is a bad name.

I added a manual guard instead

// break;
#ifndef USE_ROCM
case xla::ffi::DataType::F8E8M0FNU:
return DType::kFloat8E8M0;
Collaborator:

TE supports this datatype. Why is it guarded?

Collaborator Author:

Removed. Thanks

@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
    """Return True if using Blackwell architecture, False otherwise."""
    compute_capability = get_device_compute_capability()
Collaborator:

Need ROCm path here

Collaborator Author:

Added guard. Thanks
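
A rough sketch of what such a ROCm path could look like; the is_hip() helper and the returned values are assumptions for illustration, not the code that was committed:

import functools

@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
    """Return True if the device supports FP8 GEMM with all operand layouts."""
    if is_hip():  # hypothetical ROCm-detection helper, assumed to exist in this module
        # Assumption for the sketch: report False on ROCm until the capability is wired up.
        return False
    compute_capability = get_device_compute_capability()
    return compute_capability >= 100  # Blackwell (sm_100) and newer on the CUDA side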

# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]
Collaborator:

Since .toml does not support conditions, we need to replace jax[cuda12] with just jax, or make it ROCm-specific and request jax[rocm].

Collaborator Author (@wangye805, Nov 20, 2025):

Done. Thanks

In fact, will this requires list even take effect, since we are using "--no-build-isolation"?
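
For reference, the framework-neutral requires line the suggestion points toward would look roughly like this (a sketch; and as noted above, with --no-build-isolation the list may not take effect at build time anyway):

# Sketch only: drop the CUDA-specific extra so the build requirement stays framework-neutral.
[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax", "flax>=0.7.1"]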

@alextmagro (Contributor):

LGTM -- only covered common dir and cpp tests.
