Commit 6db324b

Merge pull request #592 from joker-star-l/dynamic-shape
Support dynamic batch size on TensorFlow
2 parents: bcf2835 + 7cfb365

File tree: 3 files changed (+47, -25 lines)


wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/op/nn/TensorflowBatchNorm3D.java

Lines changed: 16 additions & 3 deletions
@@ -23,9 +23,13 @@
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Shape;
 import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
 import org.tensorflow.types.family.TNumber;
 
+import java.util.Arrays;
+
 public class TensorflowBatchNorm3D<T extends TNumber> {
     private final Ops tf;
     private final BatchNorm3D op;
@@ -44,9 +48,18 @@ public TensorflowBatchNorm3D(Graph graph, Ops tf, BatchNorm3D op, Class<T> tClass
     }
 
     public Operand<T> call(Operand<T> input, Operand<TBool> trainingMode) {
-        long[] s = input.shape().asArray(); // N, C, D, H, W
-        Operand<T> input2D = tf.reshape(input, tf.array(s[0], s[1], s[2], -1)); // N, C, D, H * W
+        // input: N, C, D, H, W
+        Shape<TInt32> inputShape = tf.shape(input);
+        Operand<TInt32> square = tf.math.mul(tf.shape.size(inputShape, tf.constant(3)), tf.shape.size(inputShape, tf.constant(4)));
+        Operand<TInt32> newShape = tf.concat(
+                Arrays.asList(
+                        tf.shape.take(inputShape, tf.constant(3)),
+                        square
+                ),
+                tf.constant(0)
+        ); // N, C, D, H * W
+        Operand<T> input2D = tf.reshape(input, newShape);
         Operand<T> output = batchNorm2D.call(input2D, trainingMode);
-        return tf.withName(op.getName()).reshape(output, tf.constant(s));
+        return tf.withName(op.getName()).reshape(output, inputShape);
     }
 }
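Note: the change replaces the build-time shape lookup (input.shape().asArray(), which reports -1 for a dynamic batch dimension) with shape operators evaluated inside the graph, so the reshape now works for any batch size at run time. A minimal, self-contained sketch of the same pattern, using only the calls that appear in the diff (the class and method names below are illustrative, not part of the commit):

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Shape;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

import java.util.Arrays;

public class DynamicReshapeSketch {

    // Collapses a 5-D tensor (N, C, D, H, W) into (N, C, D, H * W) without
    // knowing any of the dimensions at graph-construction time.
    public static <T extends TNumber> Operand<T> collapseLastTwoAxes(Ops tf, Operand<T> input) {
        Shape<TInt32> inputShape = tf.shape(input);                // runtime shape of the input
        Operand<TInt32> hw = tf.math.mul(
                tf.shape.size(inputShape, tf.constant(3)),         // H
                tf.shape.size(inputShape, tf.constant(4)));        // W
        Operand<TInt32> newShape = tf.concat(
                Arrays.asList(
                        tf.shape.take(inputShape, tf.constant(3)), // [N, C, D]
                        hw),                                       // [H * W]
                tf.constant(0));
        return tf.reshape(input, newShape);
    }
}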

wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/op/nn/TensorflowConvLSTM2D.java

Lines changed: 14 additions & 7 deletions
@@ -24,6 +24,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Stack;
+import org.tensorflow.types.TInt32;
 import org.tensorflow.types.family.TNumber;
 
 import java.util.ArrayList;
@@ -45,15 +46,21 @@ public TensorflowConvLSTM2D(Ops tf, ConvLSTM2D op, Class<T> tClass) {
 
     public Operand<?> call(Operand<T> input) {
         // input: [batch_size, time_step, input_dim, height, width]
-        long batchSize = input.shape().get(0);
-        long seqLen = input.shape().get(1);
-        long height = input.shape().get(3);
-        long width = input.shape().get(4);
-
-        Operand<T> h = tf.zeros(tf.array(batchSize, op.getHiddenDim(), height, width), tClass);
-        Operand<T> c = tf.zeros(tf.array(batchSize, op.getHiddenDim(), height, width), tClass);
+        Operand<TInt32> shape = tf.concat(
+                Arrays.asList(
+                        tf.shape.size(tf.shape(input), tf.constant(0)), // batch_size
+                        tf.array(op.getHiddenDim()), // hidden_dim
+                        tf.shape.size(tf.shape(input), tf.constant(3)), // height
+                        tf.shape.size(tf.shape(input), tf.constant(4)) // width
+                ),
+                tf.constant(0)
+        );
+
+        Operand<T> h = tf.zeros(shape, tClass);
+        Operand<T> c = tf.zeros(shape, tClass);
 
         String outKey = op.getOutput();
+        long seqLen = input.shape().get(1);
         List<Operand<T>> outputs = new ArrayList<>((int) seqLen);
 
         for (long t = 0; t < seqLen; t++) {
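Note: the ConvLSTM state gets the same treatment: the batch size, height, and width of the zero-initialized h and c tensors are now read from the runtime shape of the input, while hidden_dim stays a compile-time constant. Only seqLen is still taken from the static shape, since the Java for loop above unrolls the recurrence and needs a concrete bound. A small sketch of the state initialization, assuming float tensors and a plain int hiddenDim (both illustrative, not part of the commit):

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

import java.util.Arrays;

public class ConvLstmStateSketch {

    // Builds a zero tensor of shape (batch_size, hiddenDim, height, width), where
    // batch_size, height, and width come from the runtime shape of the input.
    public static Operand<TFloat32> zeroState(Ops tf, Operand<TFloat32> input, int hiddenDim) {
        Operand<TInt32> stateShape = tf.concat(
                Arrays.asList(
                        tf.shape.size(tf.shape(input), tf.constant(0)), // batch_size (dynamic)
                        tf.array(hiddenDim),                            // hidden_dim (static)
                        tf.shape.size(tf.shape(input), tf.constant(3)), // height (dynamic)
                        tf.shape.size(tf.shape(input), tf.constant(4))  // width (dynamic)
                ),
                tf.constant(0));
        return tf.zeros(stateShape, TFloat32.class);
    }
}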

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowConvLSTMIT.java

Lines changed: 17 additions & 15 deletions
@@ -21,6 +21,7 @@
 import org.apache.wayang.api.DLTrainingDataQuantaBuilder;
 import org.apache.wayang.api.JavaPlanBuilder;
 import org.apache.wayang.api.LoadCollectionDataQuantaBuilder;
+import org.apache.wayang.api.PredictDataQuantaBuilder;
 import org.apache.wayang.basic.model.DLModel;
 import org.apache.wayang.basic.model.op.*;
 import org.apache.wayang.basic.model.op.nn.*;
@@ -32,22 +33,19 @@
 import org.apache.wayang.tensorflow.Tensorflow;
 import org.junit.Test;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
 
 /**
  * Test the Tensorflow ConvLSTM integration with Wayang.
  */
 public class TensorflowConvLSTMIT {
-    private int inputDim = 1;
+    private int inputDim = 2;
     private int hiddenDim = 64;
-    private int outputDim = 1;
-    private int inputFrames = 8;
+    private int outputDim = 2;
+    private int inputFrames = 6;
     private int outputFrames = 3;
-    private int height = 16;
-    private int width = 16;
+    private int height = 17;
+    private int width = 29;
 
     private int batchSize = 16;
 
@@ -59,8 +57,8 @@ public void test() {
         int[] stride = new int[]{1};
         int numLayers = 3;
 
-        Input features = new Input(new int[]{batchSize, inputFrames, inputDim, height, width}, Input.Type.FEATURES);
-        Input labels = new Input(new int[]{batchSize, outputFrames, outputDim, height, width}, Input.Type.LABEL);
+        Input features = new Input(new int[]{-1, inputFrames, inputDim, height, width}, Input.Type.FEATURES);
+        Input labels = new Input(new int[]{-1, outputFrames, outputDim, height, width}, Input.Type.LABEL);
 
         int[] perm = new int[]{0, 2, 1, 3, 4};
 
@@ -79,8 +77,8 @@ public void test() {
                     ;
         }
         builder.layer(new Slice(new int[][]{{0, -1}, {inputFrames - outputFrames, -1}, {0, -1}, {0, -1}, {0, -1}})) // Input only the last outputFrames from ConvLSTM
-                .layer(new Reshape(new int[]{batchSize, -1, height, width}))
-                .layer(new Conv2D(hiddenDim * outputFrames, outputFrames, kernelSize, stride, "SAME", true))
+                .layer(new Reshape(new int[]{-1, hiddenDim * outputFrames, height, width}))
+                .layer(new Conv2D(hiddenDim * outputFrames, outputDim * outputFrames, kernelSize, stride, "SAME", true))
                 // .layer(new Transpose(perm)) // change channels and timeStep
                 // .layer(new Conv3D(hiddenDim, outputDim, new int[]{3, 3, 3}, stride, "SAME", true)) // FIXME: The gradient of conv3D cannot be calculated, use conv2D as a substitute.
                 // .layer(new Transpose(perm)) // change channels and timeStep
@@ -110,13 +108,17 @@ public void test() {
         JavaPlanBuilder plan = new JavaPlanBuilder(wayangContext);
 
         LoadCollectionDataQuantaBuilder<float[][][][]> X = plan.loadCollection(mockData(inputFrames, inputDim));
+        LoadCollectionDataQuantaBuilder<float[][][][]> XTest = plan.loadCollection(mockData(inputFrames, inputDim));
         LoadCollectionDataQuantaBuilder<float[][][][]> Y = plan.loadCollection(mockData(outputFrames, outputDim));
 
         DLTrainingDataQuantaBuilder<float[][][][], float[][][][]> trainingOperator = X.dlTraining(Y, model, option);
 
-        Collection<DLModel> trainedModel = trainingOperator.collect();
+        // Collection<DLModel> trainedModel = trainingOperator.collect();
+        // System.out.println(trainedModel);
 
-        System.out.println(trainedModel);
+        PredictDataQuantaBuilder<float[][][][], float[][][][]> predictOperator = trainingOperator.predict(XTest, float[][][][].class);
+        Collection<float[][][][]> predicted = predictOperator.collect();
+        System.out.println(Arrays.deepToString(predicted.iterator().next()));
     }
 
     /**
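Note: with -1 in the batch position of the Input and Reshape declarations, the test now exercises the dynamic-batch path end to end: it trains on one collection and then runs inference on a second one through the predict operator. A condensed restatement of that flow follows; it is a fragment rather than a standalone program, and plan, model, option, and mockData are the ones built in the test above:

// training data, held-out data, and labels
LoadCollectionDataQuantaBuilder<float[][][][]> X = plan.loadCollection(mockData(inputFrames, inputDim));
LoadCollectionDataQuantaBuilder<float[][][][]> XTest = plan.loadCollection(mockData(inputFrames, inputDim));
LoadCollectionDataQuantaBuilder<float[][][][]> Y = plan.loadCollection(mockData(outputFrames, outputDim));

// train on (X, Y), then predict on the held-out collection
DLTrainingDataQuantaBuilder<float[][][][], float[][][][]> training = X.dlTraining(Y, model, option);
PredictDataQuantaBuilder<float[][][][], float[][][][]> predict = training.predict(XTest, float[][][][].class);
Collection<float[][][][]> predicted = predict.collect(); // one prediction per test quantum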
