Skip to content

Commit 2099bfd

Browse files
refactor: add type casting for return values in read_transactions_from_spm
1 parent 873fab1 commit 2099bfd

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

gsppy/cli.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ def read_transactions_from_spm(file_path: str) -> List[List[str]]:
216216
try:
217217
from gsppy.utils import read_transactions_from_spm as read_spm
218218

219-
return read_spm(file_path, return_mappings=False)
219+
return cast(List[List[str]], read_spm(file_path, return_mappings=False))
220220
except Exception as e:
221221
msg = f"Error reading transaction data from SPM file '{file_path}': {e}"
222222
logging.error(msg)
@@ -363,15 +363,15 @@ def _load_dataframe_format(
363363
) -> Union[List[List[str]], List[List[Tuple[str, float]]]]:
364364
"""
365365
Load transactions from DataFrame formats (Parquet/Arrow).
366-
366+
367367
Parameters:
368368
file_path: Path to the file
369369
file_extension: File extension (lowercase)
370370
transaction_col: Transaction ID column name
371371
item_col: Item column name
372372
timestamp_col: Timestamp column name
373373
sequence_col: Sequence column name
374-
374+
375375
Returns:
376376
Loaded transactions
377377
"""
@@ -405,7 +405,7 @@ def _load_transactions_by_format(
405405
) -> Union[List[List[str]], List[List[Tuple[str, float]]]]:
406406
"""
407407
Load transactions based on specified format.
408-
408+
409409
Parameters:
410410
file_path: Path to the file
411411
file_format: Format string (lowercase)
@@ -415,10 +415,10 @@ def _load_transactions_by_format(
415415
item_col: Item column name
416416
timestamp_col: Timestamp column name
417417
sequence_col: Sequence column name
418-
418+
419419
Returns:
420420
Loaded transactions
421-
421+
422422
Raises:
423423
ValueError: If format is unknown
424424
"""
@@ -429,9 +429,7 @@ def _load_transactions_by_format(
429429
elif file_format == FileFormat.CSV.value:
430430
return read_transactions_from_csv(file_path)
431431
elif file_format in (FileFormat.PARQUET.value, FileFormat.ARROW.value):
432-
return _load_dataframe_format(
433-
file_path, file_extension, transaction_col, item_col, timestamp_col, sequence_col
434-
)
432+
return _load_dataframe_format(file_path, file_extension, transaction_col, item_col, timestamp_col, sequence_col)
435433
elif file_format == FileFormat.AUTO.value:
436434
# Auto-detect format
437435
if is_dataframe_format:

tests/test_spm_format.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
import os
1212
import tempfile
13-
from typing import Generator
13+
from typing import Generator, cast
1414

1515
import pytest
1616

@@ -110,7 +110,8 @@ def test_basic_parsing(self, simple_spm_file: str):
110110

111111
def test_basic_parsing_with_mappings(self, simple_spm_file: str):
112112
"""Test SPM parsing with token mappings."""
113-
transactions, str_to_int, int_to_str = read_transactions_from_spm(simple_spm_file, return_mappings=True)
113+
result = cast(tuple, read_transactions_from_spm(simple_spm_file, return_mappings=True))
114+
transactions, str_to_int, int_to_str = result
114115

115116
assert len(transactions) == 3
116117
assert len(str_to_int) == 6 # Unique tokens: 1, 2, 3, 4, 5, 6
@@ -261,7 +262,7 @@ def complex_spm_file(self) -> Generator[str, None, None]:
261262

262263
def test_complex_parsing(self, complex_spm_file: str):
263264
"""Test parsing complex SPM file."""
264-
transactions = read_transactions_from_spm(complex_spm_file)
265+
transactions = cast(list, read_transactions_from_spm(complex_spm_file))
265266

266267
assert len(transactions) == 4
267268
assert transactions[0] == ["1", "2", "3", "1", "4", "5"]
@@ -271,7 +272,8 @@ def test_complex_parsing(self, complex_spm_file: str):
271272

272273
def test_complex_with_mappings(self, complex_spm_file: str):
273274
"""Test complex file with mappings."""
274-
transactions, str_to_int, _ = read_transactions_from_spm(complex_spm_file, return_mappings=True)
275+
result = cast(tuple, read_transactions_from_spm(complex_spm_file, return_mappings=True))
276+
transactions, str_to_int, _ = result
275277

276278
assert len(transactions) == 4
277279
# Unique tokens: 1, 2, 3, 4, 5

0 commit comments

Comments
 (0)