Skip to content

Commit e3d4da2

Browse files
committed
Add sum cost to Arguments
1 parent 87170a7 commit e3d4da2

File tree

3 files changed

+8
-0
lines changed

3 files changed

+8
-0
lines changed

paddle/api/Arguments.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) {
137137
a.cpuSequenceDims = m->cast<paddle::IVector>(vec->getSharedPtr());
138138
}
139139

140+
float Arguments::sumCosts() const {
141+
return paddle::Argument::sumCosts(m->outputs);
142+
}
143+
140144
int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) {
141145
auto& a = m->getArg(idx);
142146
return a.getBatchSize();

paddle/api/PaddleAPI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ class Arguments {
454454
IVector* vec) throw(RangeError);
455455
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);
456456

457+
float sumCosts() const;
458+
457459
private:
458460
static Arguments* createByPaddleArgumentVector(void* ptr);
459461
void* getInternalArgumentsPtr() const;

paddle/api/test/testArguments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def test_load_arguments(self):
2222
args = swig_paddle.Arguments.createArguments(1)
2323
args.setSlotValue(0, m)
2424

25+
self.assertAlmostEqual(27.0, args.sumCosts())
26+
2527
mat = args.getSlotValue(0)
2628
assert isinstance(mat, swig_paddle.Matrix)
2729
np_mat = mat.toNumpyMatInplace()

0 commit comments

Comments
 (0)