Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
self.type_dict = type_dict
else:
self.type_dict = dict(
zip(self.dp.get_type_map(), range(self.dp.get_ntypes()))
zip(self.dp.get_type_map(), range(self.dp.get_ntypes()), strict=True)
)

def calculate(
Expand Down
8 changes: 5 additions & 3 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_model_sels(self) -> list[int | list[int]]:
def _sort_rcuts_sels(self) -> tuple[tuple[Array, Array], list[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
zipped = sorted(
zip(self.get_model_rcuts(), self.get_model_nsels()),
zip(self.get_model_rcuts(), self.get_model_nsels(), strict=True),
key=lambda x: (x[1], x[0]),
)
return [p[0] for p in zipped], [p[1] for p in zipped]
Expand Down Expand Up @@ -235,12 +235,14 @@ def forward_atomic(
)
raw_nlists = [
nlists[get_multiple_nlist_key(rcut, sel)]
for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels())
for rcut, sel in zip(
self.get_model_rcuts(), self.get_model_nsels(), strict=True
)
]
nlists_ = [
nl if mt else nlist_distinguish_types(nl, extended_atype, sel)
for mt, nl, sel in zip(
self.mixed_types_list, raw_nlists, self.get_model_sels()
self.mixed_types_list, raw_nlists, self.get_model_sels(), strict=True
)
]
ener_list = []
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1]
end_idx = start_idx + np.array(sub_sel)
cut_idx = np.concatenate(
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
[range(ss, ee) for ss, ee in zip(start_idx, end_idx, strict=True)]
)
nlist_cut_idx.append(cut_idx)
self.nlist_cut_idx = nlist_cut_idx
Expand Down Expand Up @@ -310,7 +310,7 @@ def call(
)
else:
nl_distinguish_types = None
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx, strict=True):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = xp.take(nlist, nci, axis=2)
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def eval(
zip(
[x.name for x in request_defs],
out,
strict=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def build_multiple_neighbor_list(
rr = xp.where(nlist_mask, xp.full_like(rr, float("inf")), rr)
nlist0 = nlist
ret = {}
for rc, ns in zip(rcuts[::-1], nsels[::-1]):
for rc, ns in zip(rcuts[::-1], nsels[::-1], strict=True):
tnlist_1 = nlist0[:, :, :ns]
tnlist_1 = xp.where(rr[:, :, :ns] > rc, xp.full_like(tnlist_1, -1), tnlist_1)
ret[get_multiple_nlist_key(rc, ns)] = tnlist_1
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def eval(
zip(
[x.name for x in request_defs],
out,
strict=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def compute_input_stats(
sumn = []
sumv2 = []
for cc, bb, tt, nn, mm in zip(
data_coord, data_box, data_atype, natoms_vec, mesh
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
):
sysv, sysv2, sysn = self._compute_dstats_sys_nonsmth(cc, bb, tt, nn, mm)
sumv.append(sysv)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def compute_input_stats(
sumr2 = []
suma2 = []
for cc, bb, tt, nn, mm in zip(
data_coord, data_box, data_atype, natoms_vec, mesh
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
):
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
cc, bb, tt, nn, mm
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def init_variables(
start_index_old[0] = 0

for nn, oo, ii, jj in zip(
n_descpt, n_descpt_old, start_index, start_index_old
n_descpt, n_descpt_old, start_index, start_index_old, strict=True
):
if nn < oo:
# new size is smaller, copy part of std
Expand Down
8 changes: 7 additions & 1 deletion deepmd/tf/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,13 @@ def compute_input_stats(
sumr2 = []
suma2 = []
for cc, bb, tt, nn, mm, ee in zip(
data_coord, data_box, data_atype, natoms_vec, mesh, data_efield
data_coord,
data_box,
data_atype,
natoms_vec,
mesh,
data_efield,
strict=True,
):
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
cc, bb, tt, nn, mm, ee
Expand Down
10 changes: 8 additions & 2 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,13 @@ def compute_input_stats(
if mixed_type:
sys_num = 0
for cc, bb, tt, nn, mm, r_n in zip(
data_coord, data_box, data_atype, natoms_vec, mesh, real_natoms_vec
data_coord,
data_box,
data_atype,
natoms_vec,
mesh,
real_natoms_vec,
strict=True,
):
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
cc, bb, tt, nn, mm, mixed_type, r_n
Expand All @@ -392,7 +398,7 @@ def compute_input_stats(
suma2.append(sysa2)
else:
for cc, bb, tt, nn, mm in zip(
data_coord, data_box, data_atype, natoms_vec, mesh
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
):
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
cc, bb, tt, nn, mm
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def compute_input_stats(
sumn = []
sumr2 = []
for cc, bb, tt, nn, mm in zip(
data_coord, data_box, data_atype, natoms_vec, mesh
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
):
sysr, sysr2, sysn = self._compute_dstats_sys_se_r(cc, bb, tt, nn, mm)
sumr.append(sysr)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def compute_input_stats(
sumr2 = []
suma2 = []
for cc, bb, tt, nn, mm in zip(
data_coord, data_box, data_atype, natoms_vec, mesh
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
):
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
cc, bb, tt, nn, mm
Expand Down
3 changes: 2 additions & 1 deletion deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,8 @@ def eval(
output = (output,)

output_dict = {
odef.name: oo for oo, odef in zip(output, self.output_def.var_defs.values())
odef.name: oo
for oo, odef in zip(output, self.output_def.var_defs.values(), strict=True)
}
# ugly!!
if self.modifier_type is not None and issubclass(self.model_type, DeepPot):
Expand Down
10 changes: 5 additions & 5 deletions deepmd/tf/nvnmd/entrypoints/mapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,11 @@ def run_u2s(self):

u = N2 * np.reshape(np.arange(0, N + 1) / N, [-1, 1]) # pylint: disable=no-explicit-dtype
res_lst = run_sess(sess, vals, feed_dict={dic_ph["u"]: u})
res_dic = dict(zip(keys, res_lst))
res_dic = dict(zip(keys, res_lst, strict=True))

u2 = N2 * np.reshape(np.arange(0, N * 16 + 1) / (N * 16), [-1, 1]) # pylint: disable=no-explicit-dtype
res_lst2 = run_sess(sess, vals, feed_dict={dic_ph["u"]: u2})
res_dic2 = dict(zip(keys, res_lst2)) # reference for compare
res_dic2 = dict(zip(keys, res_lst2, strict=True)) # reference for compare

# change value
for tt in range(ndim):
Expand Down Expand Up @@ -536,11 +536,11 @@ def run_s2g(self):

s = N2 * np.reshape(np.arange(0, N + 1) / N, [-1, 1]) + smin_ # pylint: disable=no-explicit-dtype
res_lst = run_sess(sess, vals, feed_dict={dic_ph["s"]: s})
res_dic = dict(zip(keys, res_lst))
res_dic = dict(zip(keys, res_lst, strict=True))

s2 = N2 * np.reshape(np.arange(0, N * 16 + 1) / (N * 16), [-1, 1]) + smin_ # pylint: disable=no-explicit-dtype
res_lst2 = run_sess(sess, vals, feed_dict={dic_ph["s"]: s2})
res_dic2 = dict(zip(keys, res_lst2))
res_dic2 = dict(zip(keys, res_lst2, strict=True))

sess.close()
return res_dic, res_dic2
Expand Down Expand Up @@ -601,7 +601,7 @@ def run_t2g(self):
vals = list(dic_ph.values())
#
res_lst = run_sess(sess, vals, feed_dict={})
res_dic = dict(zip(keys, res_lst))
res_dic = dict(zip(keys, res_lst, strict=True))

sess.close()
return res_dic
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,4 +967,4 @@ def get_data_dict(self, batch_list: list[np.ndarray]) -> dict[str, np.ndarray]:
dict[str, np.ndarray]
The dict of the loaded data.
"""
return dict(zip(self.data_keys, batch_list))
return dict(zip(self.data_keys, batch_list, strict=True))
2 changes: 1 addition & 1 deletion deepmd/utils/batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def concate_result(r: list[Any]) -> Any:
return ret

if not returned_dict:
r_list = [concate_result(r) for r in zip(*results)]
r_list = [concate_result(r) for r in zip(*results, strict=True)]
r = tuple(r_list)
if len(r) == 1:
# avoid returning tuple if callable doesn't return tuple
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def __init__(
rule = int(words[1])
filtered_data_systems = []
filtered_system_dirs = []
for sys_dir, data_sys in zip(self.system_dirs, self.data_systems):
for sys_dir, data_sys in zip(
self.system_dirs, self.data_systems, strict=True
):
if data_sys.get_natoms() <= rule:
filtered_data_systems.append(data_sys)
filtered_system_dirs.append(sys_dir)
Expand Down
5 changes: 3 additions & 2 deletions deepmd/utils/model_branch_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def as_table(self) -> str:
# Step 3: Determine actual width for each column
# For the first two columns, we already decided the exact widths above.
col_widths: list[int] = []
for idx, col in enumerate(zip(*wrapped_rows)):
for idx, col in enumerate(zip(*wrapped_rows, strict=True)):
if idx == 0:
col_widths.append(branch_col_width)
elif idx == 1:
Expand All @@ -187,7 +187,8 @@ def draw_row_line(cells_parts: list[list[str]]) -> str:
return (
"| "
+ " | ".join(
part.ljust(width) for part, width in zip(cells_parts, col_widths)
part.ljust(width)
for part, width in zip(cells_parts, col_widths, strict=True)
)
+ " |"
)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/update_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def update_one_sel(
sel = [int(self.wrap_up_4(ii * ratio)) for ii in tmp_sel]
else:
# sel is set by user
for ii, (tt, dd) in enumerate(zip(tmp_sel, sel)):
# TODO: Fix len(tmp_sel) != len(sel) for TF spin models when strict is True
# error reported by source/tests/tf/test_init_frz_model_spin.py
for ii, (tt, dd) in enumerate(zip(tmp_sel, sel, strict=False)):
if dd and tt > dd:
# we may skip warning for sel=0, where the user is likely
# to exclude such type in the descriptor
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ select = [
"TCH", # flake8-type-checking
"PYI", # flake8-pyi
"ANN", # type annotations
"B905", # zip-without-explicit-strict
]

ignore = [
Expand Down Expand Up @@ -430,9 +431,9 @@ runtime-evaluated-base-classes = ["torch.nn.Module"]
"backend/**" = ["ANN"]
"data/**" = ["ANN"]
"deepmd/tf/**" = ["TID253", "ANN"]
"deepmd/pt/**" = ["TID253"]
"deepmd/pt/**" = ["TID253", "B905"]
"deepmd/jax/**" = ["TID253"]
"deepmd/pd/**" = ["TID253", "ANN"]
"deepmd/pd/**" = ["TID253", "ANN", "B905"]

"source/**" = ["ANN"]
"source/tests/tf/**" = ["TID253", "ANN"]
Expand Down
1 change: 1 addition & 0 deletions source/tests/common/dpmodel/test_pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None:
0.035,
0.025,
],
strict=True,
):
extended_coord = np.array(
[
Expand Down
Loading