55from dataclasses import dataclass
66from typing import List , Optional
77
8+ import pytest
89import torch
910from tqdm import tqdm
1011
3031from megatron .core .models .gpt .gpt_model import GPTModel
3132from megatron .core .tensor_parallel .random import model_parallel_cuda_manual_seed
3233from megatron .core .transformer .transformer_config import TransformerConfig
34+ from megatron .core .utils import is_fa_min_version
3335from tests .unit_tests .test_utilities import Utils
3436
3537DynamicInferenceContext .ROUNDER = 4 # decreased from 64 for unit tests.
@@ -310,6 +312,9 @@ def setup_method(self, method):
310312 def teardown_method (self , method ):
311313 Utils .destroy_model_parallel ()
312314
315+ @pytest .mark .skipif (
316+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
317+ )
313318 def test_simple (self ) -> None :
314319 """Simple test that runs without errors, and validates output."""
315320
@@ -336,6 +341,9 @@ def test_simple(self) -> None:
336341 for request , expected_output in zip (env .requests , expected_outputs ):
337342 assert request .output == expected_output
338343
344+ @pytest .mark .skipif (
345+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
346+ )
339347 def test_overflow_factor (self ) -> None :
340348 """Test overflow factor arg."""
341349
@@ -350,6 +358,9 @@ def test_overflow_factor(self) -> None:
350358 assert env .engine .context .max_requests == 1120
351359 assert env .engine .context .max_tokens == 1120
352360
361+ @pytest .mark .skipif (
362+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
363+ )
353364 def test_request_overflow (self ) -> None :
354365 """Test request overflow."""
355366 try :
@@ -358,6 +369,9 @@ def test_request_overflow(self) -> None:
358369 return
359370 raise Exception ("failed." )
360371
372+ @pytest .mark .skipif (
373+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
374+ )
361375 def test_token_overflow (self ) -> None :
362376 """Test token overflow."""
363377 try :
@@ -366,6 +380,9 @@ def test_token_overflow(self) -> None:
366380 return
367381 raise Exception ("failed." )
368382
383+ @pytest .mark .skipif (
384+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
385+ )
369386 def test_chunk_overflow (self ) -> None :
370387 """Test chunk overflow."""
371388 env = self ._build_test_env (TestConfig ())
@@ -378,10 +395,16 @@ def test_chunk_overflow(self) -> None:
378395 return
379396 raise Exception ("failed." )
380397
398+ @pytest .mark .skipif (
399+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
400+ )
381401 def test_multi_add (self ) -> None :
382402 """Test adding multiple requests simultaneously."""
383403 self ._run_test (num_gap_steps = 0 )
384404
405+ @pytest .mark .skipif (
406+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
407+ )
385408 def test_fixed_output_lengths (self ) -> None :
386409 """Test generating a fixed number of output tokens."""
387410 self ._run_test (use_fixed_output_lengths = True )
0 commit comments