Skip to content

Commit d732808

Browse files
committed
Merge branch 'TDDFT_GPU_phase_1' of github.com:AsTonyshment/abacus-develop into TDDFT_GPU_phase_1
2 parents fbe01cd + 5044ac5 commit d732808

File tree

4 files changed

+42
-9
lines changed

4 files changed

+42
-9
lines changed

source/module_base/module_container/ATen/core/tensor.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ Tensor::Tensor(Tensor&& other) noexcept
5353
// However, Our subclass TensorMap, etc., do not own resources.
5454
// So, we do not need to declare a virtual destructor here.
5555
Tensor::~Tensor() {
56-
if (buffer_) buffer_->unref();
56+
if (buffer_) { buffer_->unref();
57+
}
5758
}
5859

5960
// Get the data type of the tensor.
@@ -223,7 +224,8 @@ Tensor& Tensor::operator=(const Tensor& other) {
223224
this->device_ = other.device_;
224225
this->data_type_ = other.data_type_;
225226
this->shape_ = other.shape_;
226-
if (buffer_) buffer_->unref();
227+
if (buffer_) { buffer_->unref();
228+
}
227229

228230
this->buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_));
229231

@@ -241,7 +243,8 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
241243
this->data_type_ = other.data_type_;
242244
this->shape_ = other.shape_;
243245

244-
if (buffer_) buffer_->unref(); // Release current resource
246+
if (buffer_) { buffer_->unref(); // Release current resource
247+
}
245248
this->buffer_ = other.buffer_;
246249
other.buffer_ = nullptr; // Reset the other TensorBuffer.
247250
return *this;
@@ -284,7 +287,8 @@ bool Tensor::AllocateFrom(const Tensor& other, const TensorShape& shape) {
284287
data_type_ = other.data_type_;
285288
device_ = other.device_;
286289
shape_ = shape;
287-
if (buffer_) buffer_->unref();
290+
if (buffer_) { buffer_->unref();
291+
}
288292
buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_));
289293
return true;
290294
}
@@ -324,6 +328,7 @@ Tensor Tensor::operator[](const int& index) const {
324328
// Overloaded operator<< for the Tensor class.
325329
std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
326330
std::ios::fmtflags flag(os.flags());
331+
std::streamsize precision = os.precision(); // save the current precision
327332
const int64_t num_elements = tensor.NumElements();
328333
const DataType data_type = tensor.data_type();
329334
const DeviceType device_type = tensor.device_type();
@@ -398,6 +403,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
398403
#endif
399404
// restore the os settings
400405
os.flags(flag);
406+
os.precision(precision); // restore the precision
401407
return os;
402408
}
403409

source/module_cell/read_stru.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
namespace unitcell
66
{
77
bool check_tau(const Atom* atoms,
8-
const int ntype,
9-
const int lat0)
8+
const int& ntype,
9+
const double& lat0)
1010
{
1111
ModuleBase::TITLE("UnitCell","check_tau");
1212
ModuleBase::timer::tick("UnitCell","check_tau");

source/module_cell/read_stru.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
namespace unitcell
77
{
88
bool check_tau(const Atom* atoms,
9-
const int ntype,
10-
const int lat0);
9+
const int& ntype,
10+
const double& lat0);
1111

1212
bool read_atom_species(std::ifstream& ifa,
1313
std::ofstream& ofs_running,

source/module_cell/test/unitcell_test.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ TEST_F(UcellTest, CheckDTau)
753753
}
754754
}
755755

756-
TEST_F(UcellTest, CheckTau)
756+
TEST_F(UcellTest, CheckTauFalse)
757757
{
758758
UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"];
759759
PARAM.input.relax_new = utp.relax_new;
@@ -769,6 +769,33 @@ TEST_F(UcellTest, CheckTau)
769769
remove("checktau_warning");
770770
}
771771

772+
TEST_F(UcellTest, CheckTauTrue)
773+
{
774+
UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"];
775+
PARAM.input.relax_new = utp.relax_new;
776+
ucell = utp.SetUcellInfo();
777+
GlobalV::ofs_warning.open("checktau_warning");
778+
int atom=0;
779+
//cause the ucell->lat0 is 0.5,if the type of the check_tau has
780+
//an int type,it will set to zero,and it will not pass the unittest
781+
ucell->lat0=0.5;
782+
ucell->nat=3;
783+
for (int it=0;it<ucell->ntype;it++)
784+
{
785+
for(int ia=0; ia<ucell->atoms[it].na; ++ia)
786+
{
787+
788+
for (int i=0;i<3;i++)
789+
{
790+
ucell->atoms[it].tau[ia][i]=((atom+i)/(ucell->nat*3.0));
791+
}
792+
atom+=3;
793+
}
794+
}
795+
EXPECT_EQ(unitcell::check_tau(ucell->atoms ,ucell->ntype, ucell->lat0),true);
796+
GlobalV::ofs_warning.close();
797+
}
798+
772799
TEST_F(UcellTest, SelectiveDynamics)
773800
{
774801
UcellTestPrepare utp = UcellTestLib["C1H2-SD"];

0 commit comments

Comments
 (0)