Skip to content

Commit b5a5c31

Browse files
committed
fixed formatting; fixed issue when attn_name not specified
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent 4118652 commit b5a5c31

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

aiu_fms_testing_utils/testing/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import torch
55
from aiu_fms_testing_utils.utils.aiu_setup import dprint
6-
import os
76
from aiu_fms_testing_utils._version import version_tuple
7+
import os
88

99

1010
class LogitsExtractorHook(
@@ -261,7 +261,7 @@ def extract_validation_information(
261261
**extra_kwargs,
262262
):
263263
attention_specific_kwargs = {}
264-
if "paged" in extra_kwargs["attn_name"]:
264+
if "paged" in extra_kwargs.get("attn_name", "sdpa"):
265265
from aiu_fms_testing_utils.utils.paged import generate
266266
else:
267267
# TODO: Add a unified generation dependent on attn_type

tests/testing/test_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from aiu_fms_testing_utils._version import version_tuple
1212
from fms.models import get_model
1313
from fms.utils.generation import pad_input_ids
14-
import torch
1514
from pathlib import Path
15+
import torch
1616

1717

1818
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)