Skip to content

Commit b12dc58

Browse files
committed
Fixes from PR comments, unit tests
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent f4ec836 commit b12dc58

File tree

4 files changed

+6
-3
lines changed

4 files changed

+6
-3
lines changed

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from fms_mo.prep import available_packages
3232

3333
# pylint: disable=not-callable
34-
# torch.nn.functional.scaled_dot_product_attention not recognized as callable
34+
# torch.nn.functional.linear not recognized as callable
3535
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
3636

3737
# Gated torchao imports for FP8 implementation

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def scaled_paged_attn_store(
9999
Scales key and value tensors, and stores them to the paged KV cache
100100
using the same schema as vLLM.
101101
"""
102+
print("Should never hit")
102103
result_key_cache = key_cache.clone()
103104
result_value_cache = value_cache.clone()
104105
for seq_i, slot_mapping_seq in enumerate(slot_mapping):
@@ -150,6 +151,7 @@ def scaled_paged_attn_compute(
150151
Implements a CPU fallback to run the kernel that has been confirmed
151152
to match the vLLM fused kernel.
152153
"""
154+
print("Should never hit")
153155
# torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
154156
output = torch.zeros_like(query)
155157
num_query_heads = query.shape[2]

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ dependencies = [
3030
"datasets>=3.0.0,<4.0",
3131
"pandas",
3232
"safetensors",
33-
"pkginfo>1.10",
34-
"torchao"
33+
"pkginfo>1.10"
3534
]
3635

3736
[project.optional-dependencies]

tox.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ deps =
3333
pytest
3434
pylint>=2.16.2,<4.0
3535
pylint-pydantic
36+
ibm-fms
37+
torchao
3638
commands =
3739
{basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/
3840

0 commit comments

Comments
 (0)