Skip to content

Commit 82abd16

Browse files
committed
Add a input shape check for DeePKS model.
1 parent 4db39e9 commit 82abd16

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

source/source_lcao/module_deepks/deepks_basic.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,15 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
197197
// input_dim:(natom, des_per_atom)
198198
inputs.push_back(torch::cat(descriptor, 0).reshape({1, nat, des_per_atom}));
199199
std::vector<torch::Tensor> ec;
200-
ec.push_back(model_deepks.forward(inputs).toTensor()); // Hartree
200+
try
201+
{
202+
ec.push_back(model_deepks.forward(inputs).toTensor()); // Hartree
203+
}
204+
catch (const c10::Error& e)
205+
{
206+
ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm", "Please check whether the input shape required by model file matches the descriptor!");
207+
throw;
208+
}
201209
E_delta = ec[0].item<double>() * 2; // Ry; *2 is for Hartree to Ry
202210

203211
// cal gedm

0 commit comments

Comments
 (0)