Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 64 additions & 12 deletions cpp/kernels/xqa/mha.cu
Original file line number Diff line number Diff line change
Expand Up @@ -466,20 +466,53 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
#define MMAS_N_PER_MASK 2

__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
#endif
)
{
uint32_t const idxInQuad = laneId() % 4;
uint32_t const idxQuad = laneId() / 4;
// Packed mask is aligned with 32 bits (2 uint16_t).
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
constexpr uint64_t fullMask = ~uint64_t{0};
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
#else
constexpr bool ctaNeedBegMask = false;
bool const ctaNeedSpecDecMask = true;
int32_t const tok0NbMaskOut = -2147483648;
#endif
bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;

if (!needMask)
{
return;
}
#pragma unroll
for (uint32_t m = 0; m < acc.rows; m++)
{
#pragma unroll
for (uint32_t i = 0; i < InstAcc::rows; i++)
{
uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8) / headGrpSize;
uint32_t const tokenRow = min(idxQTokInCta, actualQSeqLen - 1);
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta);
uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
#else
uint64_t const begMask = fullMask;
#endif

#pragma unroll
for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
{
Expand All @@ -491,12 +524,15 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
? 0u
: min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
uint32_t packedMask = 0u;
uint32_t const maskPosStart = (maskPos0 / 16) * 16;
reinterpret_cast<uint16_t*>(&packedMask)[0]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
reinterpret_cast<uint16_t*>(&packedMask)[1]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
uint32_t packedMask = ~uint32_t{0};
if (ctaNeedSpecDecMask)
{
reinterpret_cast<uint16_t*>(&packedMask)[0]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
reinterpret_cast<uint16_t*>(&packedMask)[1]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
}
#pragma unroll
for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
{
Expand All @@ -510,7 +546,11 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
bool const maskFlag = col + actualQSeqLen < nbValidCols
? true
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;

bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true;

acc(m, n)(i, j)
= maskFlag && begMaskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
}
}
}
Expand Down Expand Up @@ -1611,8 +1651,14 @@ CUBIN_EXPORT __global__
#endif

uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
#if SLIDING_WINDOW
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);

#elif SLIDING_WINDOW
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
assert(!SPEC_DEC || !rtIsReallySliding);
uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
#else
constexpr bool rtIsReallySliding = false;
Expand All @@ -1626,7 +1672,9 @@ CUBIN_EXPORT __global__
#endif

uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
#if SPEC_DEC
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
#elif SPEC_DEC
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#endif

Expand Down Expand Up @@ -1912,8 +1960,12 @@ CUBIN_EXPORT __global__
if (seqIter >= nbSeqItersWithoutMask)
{
uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
applyMaskFromInput(
warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
#endif
);
}
#else
bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
Expand Down
1 change: 1 addition & 0 deletions jenkins/L0_Test.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,7 @@ def launchTestJobs(pipeline, testFilter)

x86SlurmTestConfigs = [
"DGX_H100-2_GPUs-PyTorch-Others-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
"DGX_H100-2_GPUs-PyTorch-GptOss-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
"DGX_H100-2_GPUs-PyTorch-Ray-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
"DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
Expand Down
11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def run(
self.spec_decoding_generation_lengths,
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]
if get_sm_version() >= 100:
if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
spec_decoding_tensor_params.append(
self.spec_decoding_bl_tree_mask_offset)
spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
Expand Down Expand Up @@ -1219,12 +1219,12 @@ def update_spec_dec_param(

# spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree.
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
get_sm_version() < 100 or get_sm_version() == 120)
not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))

self.is_spec_dec_tree = spec_tree_manager is not None
self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree

if get_sm_version() >= 100 and get_sm_version() != 120:
if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."

Expand Down Expand Up @@ -1260,7 +1260,7 @@ def update_spec_dec_param(
device='cuda',
)

if get_sm_version() >= 100:
if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
self.spec_decoding_param_prepare_for_blackwell()
else:
self.spec_decoding_bl_tree_mask_offset = None
Expand Down Expand Up @@ -1371,6 +1371,9 @@ def generate_spec_decoding_generation_length(self, max_draft_len):
self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
spec_decoding_generation_length, non_blocking=True)

def is_sm_version_trtllm_gen_kernel(self, sm):
    """Return True when SM version ``sm`` selects the trtllm-gen kernel path.

    Versions below 100 are excluded, as are 120 and 121 — presumably the
    consumer Blackwell variants that fall back to the non-gen kernels
    (TODO confirm against kernel dispatch).
    """
    # Equivalent to the original: not (sm < 100 or sm in [120, 121])
    if sm < 100:
        return False
    return sm not in (120, 121)


class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):

Expand Down
88 changes: 85 additions & 3 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4248,14 +4248,16 @@ def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
def test_eagle3(self, moe_backend, one_model, overlap_scheduler, mocker):
def test_eagle3_4gpus(self, moe_backend, one_model, overlap_scheduler,
mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")

if get_sm_version() == 90 and moe_backend == "CUTLASS":
if get_sm_version() == 90:
pytest.skip(
"https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue")
"https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue for only TP=4"
)

MAX_OUTPUT_LEN = 128179
MAX_INPUT_LEN = 32768
Expand Down Expand Up @@ -4318,6 +4320,86 @@ def test_eagle3(self, moe_backend, one_model, overlap_scheduler, mocker):
sampling_params=sampling_params,
extra_evaluator_kwargs=extra_evaluator_kwargs)

@pytest.mark.skip_less_device(2)
@pytest.mark.timeout(14400)
@pytest.mark.parametrize("overlap_scheduler", [True, False],
ids=["overlap_scheduler", "no_overlap_scheduler"])
@pytest.mark.parametrize("one_model", [True, False],
ids=["one_model", "two_model"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
def test_eagle3_2gpus(self, moe_backend, one_model, overlap_scheduler,
mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")

MAX_OUTPUT_LEN = 128179
MAX_INPUT_LEN = 32768

mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})

mocker.patch.object(GPQADiamond, "MAX_OUTPUT_LEN", MAX_OUTPUT_LEN)
mocker.patch.object(GPQADiamond, "MAX_INPUT_LEN", MAX_INPUT_LEN)

# https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
pytorch_config = dict(
max_batch_size=8,
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig(max_batch_size=8))
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
dtype="auto")

eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model)

max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,
tensor_parallel_size=2,
pipeline_parallel_size=1,
moe_expert_parallel_size=1,
kv_cache_config=kv_cache_config,
max_seq_len=max_seq_len,
speculative_config=spec_config,
**pytorch_config,
enable_attention_dp=False,
moe_config=MoeConfig(backend=moe_backend))

with llm:
model_name = "GPT-OSS/120B-MXFP4"

# GSM8K
task = GSM8K(model_name)
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)

# GPQA Medium Reasoning
task = GPQADiamond(model_name)

chat_template_kwargs = dict(reasoning_effort="medium")
extra_evaluator_kwargs = {
**self.extra_evaluator_kwargs, "chat_template_kwargs":
chat_template_kwargs
}

sampling_params = SamplingParams(
temperature=1.0,
top_p=1.0,
max_tokens=MAX_OUTPUT_LEN,
truncate_prompt_tokens=MAX_INPUT_LEN)

task.evaluate(llm,
sampling_params=sampling_params,
extra_evaluator_kwargs=extra_evaluator_kwargs)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["GB200"])
@pytest.mark.parametrize(
Expand Down
24 changes: 12 additions & 12 deletions tests/integration/test_lists/qa/llm_function_core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -566,18 +566,18 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-au
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-no_overlap_scheduler]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
Expand Down
24 changes: 12 additions & 12 deletions tests/integration/test_lists/qa/llm_function_core_sanity.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,18 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-au
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-no_overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
Expand Down
Loading