Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/dynamic_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ def dynamic_metadata(
]
optional_dependencies["lmp"].extend(find_libpython_requires)
optional_dependencies["ipi"].extend(find_libpython_requires)
torch_static_requirement = optional_dependencies.pop("torch", ())
return {
**optional_dependencies,
**get_tf_requirement(tf_version),
**get_pt_requirement(pt_version),
**get_pt_requirement(
pt_version, static_requirement=tuple(torch_static_requirement)
),
}
10 changes: 9 additions & 1 deletion backend/find_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def find_pytorch() -> tuple[str | None, list[str]]:


@lru_cache
def get_pt_requirement(pt_version: str = "") -> dict:
def get_pt_requirement(
pt_version: str = "",
static_requirement: tuple[str] | None = None,
) -> dict:
"""Get PyTorch requirement when PT is not installed.

If pt_version is not given and the environment variable `PYTORCH_VERSION` is set, use it as the requirement.
Expand All @@ -99,6 +102,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
----------
pt_version : str, optional
PT version
static_requirement : tuple[str] or None, optional
Static requirements

Returns
-------
Expand All @@ -125,6 +130,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
mpi_requirement = ["mpich"]
else:
mpi_requirement = []
if static_requirement is None:
static_requirement = ()

return {
"torch": [
Expand All @@ -138,6 +145,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
else "torch>=2.1.0",
*mpi_requirement,
*cibw_requirement,
*static_requirement,
],
}

Expand Down
1 change: 0 additions & 1 deletion deepmd/dpmodel/modifier/base_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def serialize(self) -> dict:
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
Expand Down
18 changes: 6 additions & 12 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
import io
import json
import logging
import os
import pickle
from pathlib import (
Path,
)
Expand Down Expand Up @@ -401,17 +401,11 @@ def freeze(
model.eval()
model = torch.jit.script(model)

dm_output = "data_modifier.pth"
extra_files = {dm_output: ""}
if tester.modifier is not None:
dm = tester.modifier
dm.eval()
buffer = io.BytesIO()
torch.jit.save(
torch.jit.script(dm),
buffer,
)
extra_files = {dm_output: buffer.getvalue()}
extra_files = {"modifier_data": ""}
dm = tester.modifier
if dm is not None:
bytes_data = pickle.dumps(dm.serialize())
extra_files = {"modifier_data": bytes_data}
torch.jit.save(
model,
output,
Expand Down
18 changes: 11 additions & 7 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import io
import json
import logging
import pickle
from collections.abc import (
Callable,
)
Expand Down Expand Up @@ -49,6 +49,9 @@
from deepmd.pt.model.network.network import (
TypeEmbedNetConsistent,
)
from deepmd.pt.modifier import (
BaseModifier,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
Expand Down Expand Up @@ -172,19 +175,20 @@ def __init__(
self.dp = ModelWrapper(model)
self.dp.load_state_dict(state_dict)
elif str(self.model_path).endswith(".pth"):
extra_files = {"data_modifier.pth": ""}
extra_files = {"modifier_data": ""}
model = torch.jit.load(
model_file, map_location=env.DEVICE, _extra_files=extra_files
)
modifier = None
# Load modifier if it exists in extra_files
if len(extra_files["data_modifier.pth"]) > 0:
# Create a file-like object from the in-memory data
modifier_data = extra_files["data_modifier.pth"]
if len(extra_files["modifier_data"]) > 0:
modifier_data = extra_files["modifier_data"]
if isinstance(modifier_data, bytes):
modifier_data = io.BytesIO(modifier_data)
modifier_data = pickle.loads(modifier_data)
# Load the modifier directly from the file-like object
modifier = torch.jit.load(modifier_data, map_location=env.DEVICE)
modifier = BaseModifier.get_class_by_type(
modifier_data["type"]
).deserialize(modifier_data)
self.dp = ModelWrapper(model, modifier=modifier)
self.modifier = modifier
model_def_script = self.dp.model["Default"].get_model_def_script()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/modifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from .base_modifier import (
BaseModifier,
)
from .dipole_charge import (
DipoleChargeModifier,
)

__all__ = [
"BaseModifier",
"DipoleChargeModifier",
"get_data_modifier",
]

Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/modifier/base_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def serialize(self) -> dict:
data = {
"@class": "Modifier",
"type": self.modifier_type,
"use_cache": self.use_cache,
"@version": 3,
}
return data
Expand Down
Loading
Loading