forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: create_dlrm_mar.py
More file actions
40 lines (28 loc) · 848 Bytes
/
create_dlrm_mar.py
File metadata and controls
40 lines (28 loc) · 848 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
This script creates a DLRM model and packs it into a TorchServe mar file
"""
import os
import subprocess

import torch

from dlrm_factory import DLRMFactory
# Filename (relative to the current working directory) that main() writes the
# DLRM state_dict to; the same path is passed to torch-model-archiver as
# --serialized-file.
MODEL_PT_FILE = "dlrm.pt"
def create_pt_file(output_file: str) -> None:
    """Build a default DLRM model and save its weights to *output_file*.

    The module returned by ``DLRMFactory()`` is moved to CPU before its
    ``state_dict`` is taken, so the serialized tensors live on CPU.

    Args:
        output_file: Destination path handed to ``torch.save``.
    """
    cpu_model = DLRMFactory().cpu()
    weights = cpu_model.state_dict()
    torch.save(weights, output_file)
def main() -> None:
    """Save a freshly built DLRM state_dict and package it with torch-model-archiver.

    Steps:
      1. Write the model's state_dict to ``MODEL_PT_FILE`` via ``create_pt_file``.
      2. Run ``torch-model-archiver`` with model name ``dlrm``, version ``1.0``,
         ``dlrm_factory.py`` as the model file, ``dlrm_model_config.py`` as an
         extra file and ``dlrm_handler.py`` as the handler. ``--force`` is
         passed so an existing archive is overwritten.

    Raises:
        subprocess.CalledProcessError: if torch-model-archiver exits non-zero.
        FileNotFoundError: if the torch-model-archiver executable is not found.
    """
    print(f"Creating DLRM model and saving state_dict to {MODEL_PT_FILE}")
    create_pt_file(MODEL_PT_FILE)

    # Each flag and its value are separate argv entries, so no shell is
    # involved and paths are never subject to shell word-splitting/globbing.
    cmd = [
        "torch-model-archiver",
        "--model-name", "dlrm",
        "--version", "1.0",
        "--serialized-file", MODEL_PT_FILE,
        "--model-file", "dlrm_factory.py",
        "--extra-files", "dlrm_model_config.py",
        "--handler", "dlrm_handler.py",
        "--force",
    ]
    print("Archiving model into dlrm.mar")
    # check=True: surface a failed archive as an exception instead of
    # silently falling through to "Done" as os.system's ignored status did.
    subprocess.run(cmd, check=True)
    print("Done")
# Run the build-and-archive flow only when executed as a script, not on import.
if __name__ == "__main__":
    main()