Skip to content

Commit 90095c3

Browse files
authored
Merge pull request #579 from joker-star-l/tf_demo
ConvLSTM demo
2 parents 3917971 + de8f165 commit 90095c3

File tree

32 files changed

+2053
-152
lines changed

32 files changed

+2053
-152
lines changed

wayang-api/wayang-api-json/src/main/scala/operatorfromjson/binary/DLTrainingOperatorFromJson.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ case class Op(val op: String, val opType: String, val dType: String, val fromLis
4343
case "Cast" => new Cast(parseDType(dType))
4444
case "CrossEntropyLoss" => new CrossEntropyLoss(labels)
4545
case "Eq" => new Eq()
46-
case "Input" => new Input(parseInputType(opType))
46+
case "Input" => new Input(null, parseInputType(opType))
4747
case "Mean" => new Mean(dim)
4848
case "Linear" => new Linear(inFeatures, outFeatures, bias)
4949
case "ReLU" => new ReLU()
@@ -74,7 +74,6 @@ case class Op(val op: String, val opType: String, val dType: String, val fromLis
7474
inputType match {
7575
case "..FEATURES.." => Input.Type.FEATURES
7676
case "..LABEL.." => Input.Type.LABEL
77-
case "..PREDICTED.." => Input.Type.PREDICTED
7877
}
7978
}
8079
}

wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/model/DLModel.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,26 @@ public DLModel(Op out) {
3030
public Op getOut() {
3131
return out;
3232
}
33+
34+
public static class Builder {
35+
private Op out;
36+
37+
public DLModel build() {
38+
return new DLModel(out);
39+
}
40+
41+
public Builder layer(Op op) {
42+
if (op == null) {
43+
return this;
44+
}
45+
46+
if (out == null) {
47+
out = op;
48+
} else {
49+
out = op.with(out);
50+
}
51+
52+
return this;
53+
}
54+
}
3355
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model.op;
20+
21+
public class Get extends Op {
22+
// currently, only support String
23+
private final Object key;
24+
25+
public Get(Object key) {
26+
this(key, null, DType.FLOAT32);
27+
}
28+
29+
public Get(Object key, DType dType) {
30+
this(key, null, dType);
31+
}
32+
33+
public Get(Object key, String name) {
34+
this(key, name, DType.FLOAT32);
35+
}
36+
37+
public Get(Object key, String name, DType dType) {
38+
super(name, dType);
39+
this.key = key;
40+
}
41+
42+
public Object getKey() {
43+
return key;
44+
}
45+
46+
@Override
47+
public int inputsRequired() {
48+
return 1;
49+
}
50+
}

wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/model/op/Input.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,45 @@
1919
package org.apache.wayang.basic.model.op;
2020

2121
public class Input extends Op {
22+
private final int[] shape;
2223

23-
public Input() {
24-
this(DType.FLOAT32);
24+
public Input(int[] shape) {
25+
this(shape, DType.FLOAT32);
2526
}
2627

27-
public Input(String name) {
28-
this(name, DType.FLOAT32);
28+
public Input(int[] shape, String name) {
29+
this(shape, name, DType.FLOAT32);
2930
}
3031

31-
public Input(Type type) {
32-
this(type.getName(), DType.FLOAT32);
32+
public Input(int[] shape, Type type) {
33+
this(shape, type.getName(), DType.FLOAT32);
3334
}
3435

35-
public Input(DType dType) {
36-
super(dType);
36+
public Input(int[] shape, DType dType) {
37+
this(shape, (String) null, dType);
3738
}
3839

39-
public Input(String name, DType dType) {
40-
super(name, dType);
40+
public Input(int[] shape, Type type, DType dType) {
41+
this(shape, type.getName(), dType);
4142
}
4243

43-
public Input(Type type, DType dType) {
44-
super(type.getName(), dType);
44+
public Input(int[] shape, String name, DType dType) {
45+
super(name, dType);
46+
this.shape = shape;
4547
}
4648

4749
@Override
4850
public int inputsRequired() {
4951
return 0;
5052
}
5153

54+
public int[] getShape() {
55+
return shape;
56+
}
57+
5258
public enum Type {
5359
FEATURES("..FEATURES.."),
54-
LABEL("..LABEL.."),
55-
PREDICTED("..PREDICTED..");
60+
LABEL("..LABEL..");
5661

5762
private final String name;
5863

wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/model/op/Op.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,35 @@
2525
import java.util.concurrent.atomic.AtomicInteger;
2626

2727
public abstract class Op implements Serializable {
28+
private static final AtomicInteger ID_GENERATOR = new AtomicInteger(0);
2829
private static final AtomicInteger CNT = new AtomicInteger(0);
2930

31+
protected final int id;
3032
protected final String name;
3133
protected final List<Op> fromList;
3234

3335
// output type
3436
protected final DType dType;
3537

3638
public Op(DType dType) {
37-
this.name = this.getClass().getSimpleName() + CNT.getAndIncrement();
38-
this.fromList = new ArrayList<>();
39-
this.dType = dType;
39+
this(null, dType);
4040
}
4141

4242
public Op(String name, DType dType) {
43-
this.name = name;
43+
this.id = ID_GENERATOR.getAndIncrement();
44+
if (name == null || name.isEmpty()) {
45+
this.name = this.getClass().getSimpleName() + CNT.getAndIncrement();
46+
} else {
47+
this.name = name;
48+
}
4449
this.fromList = new ArrayList<>();
4550
this.dType = dType;
4651
}
4752

53+
public int getId() {
54+
return id;
55+
}
56+
4857
public String getName() {
4958
return name;
5059
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model.op;
20+
21+
public class Reshape extends Op {
22+
private final int[] shape;
23+
24+
public Reshape(int[] shape) {
25+
this(shape, null, DType.FLOAT32);
26+
}
27+
28+
public Reshape(int[] shape, DType dType) {
29+
this(shape, null, dType);
30+
}
31+
32+
public Reshape(int[] shape, String name) {
33+
this(shape, name, DType.FLOAT32);
34+
}
35+
36+
public Reshape(int[] shape, String name, DType dType) {
37+
super(name, dType);
38+
this.shape = shape;
39+
}
40+
41+
public int[] getShape() {
42+
return shape;
43+
}
44+
45+
@Override
46+
public int inputsRequired() {
47+
return 1;
48+
}
49+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model.op;
20+
21+
public class Slice extends Op {
22+
private final int[][] range; // int[dim][2]
23+
24+
public Slice(int[][] range) {
25+
this(range, null, DType.FLOAT32);
26+
}
27+
28+
public Slice(int[][] range, DType dType) {
29+
this(range, null, dType);
30+
}
31+
32+
public Slice(int[][] range, String name) {
33+
this(range, name, DType.FLOAT32);
34+
}
35+
36+
public Slice(int[][] range, String name, DType dType) {
37+
super(name, dType);
38+
this.range = range;
39+
}
40+
41+
public int[][] getRange() {
42+
return range;
43+
}
44+
45+
@Override
46+
public int inputsRequired() {
47+
return 1;
48+
}
49+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model.op;
20+
21+
public class Transpose extends Op {
22+
private final int[] perm;
23+
24+
public Transpose(int[] perm) {
25+
this(perm, null, DType.FLOAT32);
26+
}
27+
28+
public Transpose(int[] perm, DType dType) {
29+
this(perm, null, dType);
30+
}
31+
32+
public Transpose(int[] perm, String name) {
33+
this(perm, name, DType.FLOAT32);
34+
}
35+
36+
public Transpose(int[] perm, String name, DType dType) {
37+
super(name, dType);
38+
this.perm = perm;
39+
}
40+
41+
public int[] getPerm() {
42+
return perm;
43+
}
44+
45+
@Override
46+
public int inputsRequired() {
47+
return 1;
48+
}
49+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model.op;
20+
21+
public class ZeroLike extends Op {
22+
23+
public ZeroLike() {
24+
this(null, DType.FLOAT32);
25+
}
26+
27+
public ZeroLike(DType dType) {
28+
this(null, dType);
29+
}
30+
31+
public ZeroLike(String name) {
32+
this(name, DType.FLOAT32);
33+
}
34+
35+
public ZeroLike(String name, DType dType) {
36+
super(name, dType);
37+
}
38+
39+
@Override
40+
public int inputsRequired() {
41+
return 1;
42+
}
43+
}

0 commit comments

Comments
 (0)