Skip to content

Commit 0a8503d

Browse files
committed
Prevent replacing DALI op by constant in the constant folding process (#367)
During constant folding process, TensorFlow checks if nodes that have constant input produces constant outputs and replaces them with constants. In our case, it can happen as DALI op provides a deterministic output when running once just after construction. Also, DALI op provides CPU variant so now TensorFlow can run DALI on CPU for constant folding examination. To prevent this DALI op needs to tell that it is stateful. Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 1e166d9 commit 0a8503d

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

dali/tensorflow/daliop.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ REGISTER_OP("Dali")
6565
.Attr("prefetch_queue_depth: int = 2")
6666
.Output("data: dtypes")
6767
.Attr("dtypes: list({half, float, uint8, int16, int32, int64}) >= 1")
68+
// To prevent replacing DALI op with constant tensor during TF constant folding process
69+
.SetIsStateful()
6870
.SetShapeFn([](tf::shape_inference::InferenceContext* c) {
6971
std::vector<tf::PartialTensorShape> shapes;
7072
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));

0 commit comments

Comments
 (0)