Skip to content

Commit 9d35d5d

Browse files
committed
Github Actions setup
1 parent cd8c9d2 commit 9d35d5d

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

.github/workflows/test-export.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
name: Test Few-Shot Model Export
22

33
on:
4+
workflow_dispatch:
5+
46
push:
57
branches: [ main ]
68
pull_request:

testing/test1.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
1+
# testing/test1.py
2+
13
import torch
4+
import torchvision.transforms as T
5+
import pprint
6+
7+
# Trust the transforms used in the model export
8+
torch.serialization.add_safe_globals({
9+
'torchvision.transforms.transforms.Compose': T.Compose,
10+
'torchvision.transforms.transforms.Resize': T.Resize,
11+
'torchvision.transforms.transforms.ToTensor': T.ToTensor
12+
})
213

314
def test_model_export(model_path):
4-
print(f"Loading model from: {model_path}")
5-
config = torch.load(model_path, map_location='cpu')
15+
print(f"\nLoading model from: {model_path}")
16+
17+
config = torch.load(model_path, map_location='cpu', weights_only=False)
18+
19+
print("\n--- Exported Model Content ---")
620

721
for key, value in config.items():
8-
if torch.is_tensor(value):
9-
print(f"{key}: Tensor with shape {tuple(value.shape)}")
22+
if isinstance(value, torch.Tensor):
23+
print(f"{key}: Tensor shape = {tuple(value.shape)}")
1024
else:
11-
print(f"{key}: {value}")
25+
print(f"{key}:")
26+
pprint.pprint(value)
1227

13-
# Run
14-
test_model_export("export/fewshot_model.pt")
28+
if __name__ == "__main__":
29+
test_model_export("export/fewshot_model.pt")

0 commit comments

Comments
 (0)