Skip to content

Commit 94324c8

Browse files
fix workflow failures
Signed-off-by: Takeshi Yoshimura <tyos@jp.ibm.com>
1 parent bb18069 commit 94324c8

File tree

17 files changed

+181
-72
lines changed

17 files changed

+181
-72
lines changed

.github/workflows/test-paddle.yaml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,10 @@ jobs:
4242
cd tests
4343
LIBDIR=`python3 -c "import os; os.chdir('/tmp'); import fastsafetensors; print(os.path.dirname(fastsafetensors.__file__))"`
4444
mkdir -p /tmp/pytest-log
45-
export TEST_FASTSAFETENSORS_FRAMEWORK=torch
46-
COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1
47-
COVERAGE_FILE=.coverage_1 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/1.log 2>&1 &
48-
COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/2.log 2>&1
4945
export TEST_FASTSAFETENSORS_FRAMEWORK=paddle
50-
COVERAGE_FILE=.coverage_3 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/3.log 2>&1
51-
COVERAGE_FILE=.coverage_4 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/4.log 2>&1 & \
52-
COVERAGE_FILE=.coverage_5 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 1 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/5.log 2>&1 && \
46+
COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1
47+
COVERAGE_FILE=.coverage_1 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 0 tests/test_multi.py --cov=$(LIBDIR) -s tests/test_multi.py > /tmp/pytest-log/1.log 2>&1 & \
48+
COVERAGE_FILE=.coverage_2 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 1 tests/test_multi.py --cov=$(LIBDIR) -s tests/test_multi.py > /tmp/pytest-log/2.log 2>&1 && \
5349
coverage combine .coverage_*
5450
coverage html
5551
mv htmlcov /tmp/pytest-log

.github/workflows/test-torch.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ jobs:
5858
mkdir -p /tmp/pytest-log
5959
export TEST_FASTSAFETENSORS_FRAMEWORK=pytorch
6060
COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1
61-
COVERAGE_FILE=.coverage_1 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/1.log 2>&1 &
62-
COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/pytest-log/2.log 2>&1
61+
COVERAGE_FILE=.coverage_1 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 tests/test_multi.py --cov=$(LIBDIR) -s tests/test_multi.py > /tmp/pytest-log/1.log 2>&1 &
62+
COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 tests/test_multi.py --cov=$(LIBDIR) -s tests/test_multi.py > /tmp/pytest-log/2.log 2>&1
6363
coverage combine .coverage_*
6464
coverage html
6565
mv htmlcov /tmp/pytest-log

examples/run_parallel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# PIDS+=$($!)
88
# wait ${PIDS[@]}
99

10+
1011
def run_torch():
1112
import torch
1213
import torch.distributed as dist
@@ -17,6 +18,7 @@ def run_torch():
1718
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1819
return pg, device
1920

21+
2022
def run_paddle():
2123
import paddle
2224
import paddle.distributed as dist
@@ -27,13 +29,15 @@ def run_paddle():
2729
device = "gpu" if paddle.device.cuda.device_count() else "cpu"
2830
return pg, device
2931

32+
3033
runs = {
3134
"torch": run_torch,
3235
"paddle": run_paddle,
3336
}
3437

3538
if __name__ == "__main__":
3639
import sys
40+
3741
from fastsafetensors import SafeTensorsFileLoader
3842

3943
framework = "torch"

examples/run_single.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,41 @@
11
#!/usr/bin/env python3
22

3+
34
def run_torch():
45
import torch
6+
57
device = "cuda:0" if torch.cuda.is_available() else "cpu"
68
return device
79

10+
811
def run_paddle():
912
import paddle
13+
1014
device = "gpu" if paddle.device.cuda.device_count() else "cpu"
1115
return device
1216

17+
1318
runs = {
1419
"torch": run_torch,
1520
"paddle": run_paddle,
1621
}
1722

1823
if __name__ == "__main__":
1924
import sys
20-
from fastsafetensors import fastsafe_open
25+
2126
from fastsafetensors import cpp as fstcpp
27+
from fastsafetensors import fastsafe_open
2228

2329
framework = "torch"
2430
if len(sys.argv) > 1:
2531
framework = sys.argv[1]
2632

2733
device = runs[framework]()
28-
with fastsafe_open(["a.safetensors", "b.safetensors"], device=device, nogds=not fstcpp.is_cufile_found(), framework=framework) as f:
34+
with fastsafe_open(
35+
["a.safetensors", "b.safetensors"],
36+
device=device,
37+
nogds=not fstcpp.is_cufile_found(),
38+
framework=framework,
39+
) as f:
2940
print(f"a0: {f.get_tensor(name='a0')}")
3041
print(f"b0: {f.get_tensor(name='b0')}")

fastsafetensors/common.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import cpp as fstcpp
1212
from .dlpack import from_cuda_buffer
13-
from .frameworks import TensorBase, FrameworkOpBase
13+
from .frameworks import FrameworkOpBase, TensorBase
1414
from .st_types import Device, DType
1515

1616

@@ -80,7 +80,9 @@ def __init__(
8080
)
8181

8282
@classmethod
83-
def from_buffer(self, buf: int, buffer_len: int, filename: str, framework: FrameworkOpBase):
83+
def from_buffer(
84+
self, buf: int, buffer_len: int, filename: str, framework: FrameworkOpBase
85+
):
8486
if buffer_len < 8:
8587
raise Exception(
8688
f"from_buffer: HeaderTooSmall, filename={filename}, buffer_len={buffer_len}"
@@ -173,7 +175,9 @@ def get_tensors(
173175
t2 = t2.view(t.dtype)
174176

175177
if dtype != DType.AUTO and dtype != t.dtype:
176-
if self.framework.get_dtype_size(dtype) > self.framework.get_dtype_size(t.dtype):
178+
if self.framework.get_dtype_size(dtype) > self.framework.get_dtype_size(
179+
t.dtype
180+
):
177181
raise Exception(
178182
f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})"
179183
)

fastsafetensors/copier/gds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .. import cpp as fstcpp
77
from ..common import SafeTensorsMetadata
8-
from ..frameworks import TensorBase, FrameworkOpBase
8+
from ..frameworks import FrameworkOpBase, TensorBase
99
from ..st_types import Device, DeviceType, DType
1010

1111

fastsafetensors/copier/nogds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .. import cpp as fstcpp
88
from ..common import SafeTensorsMetadata
9-
from ..frameworks import TensorBase, FrameworkOpBase
9+
from ..frameworks import FrameworkOpBase, TensorBase
1010
from ..st_types import Device, DType
1111

1212

fastsafetensors/file_buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import OrderedDict
55
from typing import Dict, List, Optional, Tuple
66

7-
from .frameworks import ProcessGroupBase, TensorBase, FrameworkOpBase
7+
from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase
88
from .st_types import Device, DType
99
from .tensor_factory import LazyTensorFactory
1010

fastsafetensors/frameworks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def randn(self, s: tuple, dtype: DType) -> T:
163163
def support_fp8(self) -> bool:
164164
pass
165165

166+
166167
def get_framework_op(name: str) -> FrameworkOpBase:
167168
if name == "pt" or name == "pytorch" or name == "torch":
168169
from ._torch import TorchOp

fastsafetensors/frameworks/_paddle.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
try:
55
import paddle
66
import paddle.distributed as pdist
7-
from paddle.framework import core as paddle_core
87
from paddle.distributed.communication.group import Group
8+
from paddle.framework import core as paddle_core
99
except ImportError as e:
1010
raise ImportError(
1111
"could not import paddle, paddle_core, or numpy. Please install them."
@@ -38,11 +38,12 @@
3838
DType.F8_E4M3: DType.I8,
3939
}
4040

41-
if hasattr(paddle, 'float8_e5m2'):
41+
if hasattr(paddle, "float8_e5m2"):
4242
dtype_convert[DType.F8_E5M2] = paddle.float8_e5m2
43-
if hasattr(paddle, 'float8_e4m3fn'):
43+
if hasattr(paddle, "float8_e4m3fn"):
4444
dtype_convert[DType.F8_E4M3] = paddle.float8_e4m3fn
4545

46+
4647
@dataclass
4748
class PaddleTensor(TensorBase):
4849
real_tensor: paddle.Tensor
@@ -222,7 +223,9 @@ def as_workaround_dtype(self, dtype: DType) -> DType:
222223

223224
def get_process_group(self, pg: Optional[Any]) -> PaddleProcessGroup:
224225
if pg is not None and not isinstance(pg, Group):
225-
raise Exception("pg must be an instance of paddle.distributed.communication.group.Group")
226+
raise Exception(
227+
"pg must be an instance of paddle.distributed.communication.group.Group"
228+
)
226229
return PaddleProcessGroup(pg)
227230

228231
# for testing
@@ -232,7 +235,11 @@ def is_equal(self, wrapped: PaddleTensor, real: Any) -> bool:
232235
raise Exception("real is not paddle.Tensor")
233236

234237
def randn(self, s: tuple, device: Device, dtype: DType) -> PaddleTensor:
235-
return PaddleTensor(device, dtype, paddle.randn(s, dtype=dtype_convert[dtype]).to(device=device.as_str()))
238+
return PaddleTensor(
239+
device,
240+
dtype,
241+
paddle.randn(s, dtype=dtype_convert[dtype]).to(device=device.as_str()),
242+
)
236243

237244
def support_fp8(self) -> bool:
238-
return DType.F8_E5M2 in dtype_convert
245+
return DType.F8_E5M2 in dtype_convert

0 commit comments

Comments
 (0)