@@ -1435,6 +1435,8 @@ def test_sdpa_choice(config):
1435
1435
pytest .skip ("Gemma 2 doesn't support SDPA" )
1436
1436
1437
1437
torch .set_default_dtype (torch .float16 )
1438
+ config ["n_layer" ] = 1
1439
+ config = config_module .Config (** config )
1438
1440
enable_gqa = config ["n_query_groups" ] < config ["n_head" ]
1439
1441
1440
1442
def assert_sdpa_backend (original_fn , query , k_and_v , mask , return_scores ):
@@ -1457,9 +1459,6 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1457
1459
raise NotImplementedError
1458
1460
return original_fn (query , k_and_v , mask , return_scores )
1459
1461
1460
- config ["n_layer" ] = 1
1461
- config = config_module .Config (** config )
1462
-
1463
1462
try :
1464
1463
with torch .device ("cuda" ):
1465
1464
model = GPT (config )
@@ -1488,6 +1487,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1488
1487
@torch .inference_mode ()
1489
1488
def test_sdpa_choice_kv_cache (config ):
1490
1489
torch .set_default_dtype (torch .float16 )
1490
+ config ["n_layer" ] = 1
1491
+ config = config_module .Config (** config )
1491
1492
enable_gqa = config ["n_query_groups" ] < config ["n_head" ]
1492
1493
1493
1494
def assert_sdpa_backend (original_fn , query , k_and_v , mask , return_scores ):
@@ -1509,9 +1510,6 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1509
1510
raise NotImplementedError
1510
1511
return original_fn (query , k_and_v , mask , return_scores )
1511
1512
1512
- config ["n_layer" ] = 1
1513
- config = config_module .Config (** config )
1514
-
1515
1513
try :
1516
1514
with torch .device ("cuda" ):
1517
1515
model = GPT (config )
0 commit comments