Skip to content

Commit 46e190d

Browse files
committed
feat: add RNN basic framework.
1 parent e9f2cac commit 46e190d

File tree

88 files changed

+1789
-188
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+1789
-188
lines changed

src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs renamed to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Extensions
6+
namespace Tensorflow.Common.Extensions
77
{
88
public static class JObjectExtensions
99
{
1010
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
1111
{
1212
var res = obj[key];
13-
if(res is null)
13+
if (res is null)
1414
{
15-
return default(T);
15+
return default;
1616
}
1717
else
1818
{
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class LinqExtensions
9+
{
10+
#if NETSTANDARD2_0
11+
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
12+
{
13+
return sequence.Skip(sequence.Count() - count);
14+
}
15+
16+
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
17+
{
18+
return sequence.Take(sequence.Count() - count);
19+
}
20+
#endif
21+
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
22+
{
23+
return new Tensors(tensors);
24+
}
25+
}
26+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Types
7+
{
8+
public class GeneralizedTensorShape: IEnumerable<long?[]>
9+
{
10+
public TensorShapeConfig[] Shapes { get; set; }
11+
/// <summary>
12+
/// create a single-dim generalized Tensor shape.
13+
/// </summary>
14+
/// <param name="dim"></param>
15+
public GeneralizedTensorShape(int dim)
16+
{
17+
Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
18+
}
19+
20+
public GeneralizedTensorShape(Shape shape)
21+
{
22+
Shapes = new TensorShapeConfig[] { shape };
23+
}
24+
25+
public GeneralizedTensorShape(TensorShapeConfig shape)
26+
{
27+
Shapes = new TensorShapeConfig[] { shape };
28+
}
29+
30+
public GeneralizedTensorShape(TensorShapeConfig[] shapes)
31+
{
32+
Shapes = shapes;
33+
}
34+
35+
public GeneralizedTensorShape(IEnumerable<Shape> shape)
36+
{
37+
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
38+
}
39+
40+
public Shape ToSingleShape()
41+
{
42+
if (Shapes.Length != 1)
43+
{
44+
throw new ValueError("The generalized shape contains more than 1 dim.");
45+
}
46+
var shape_config = Shapes[0];
47+
Debug.Assert(shape_config is not null);
48+
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
49+
}
50+
51+
public long ToNumber()
52+
{
53+
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
54+
{
55+
throw new ValueError("The generalized shape contains more than 1 dim.");
56+
}
57+
var res = Shapes[0].Items[0];
58+
return res is null ? -1 : res.Value;
59+
}
60+
61+
public Shape[] ToShapeArray()
62+
{
63+
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
64+
}
65+
66+
public IEnumerator<long?[]> GetEnumerator()
67+
{
68+
foreach (var shape in Shapes)
69+
{
70+
yield return shape.Items;
71+
}
72+
}
73+
74+
IEnumerator IEnumerable.GetEnumerator()
75+
{
76+
return GetEnumerator();
77+
}
78+
}
79+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This interface is used when some corresponding python methods have optional args.
9+
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
10+
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
11+
/// as the parameter of the method.
12+
/// </summary>
13+
public interface IOptionalArgs
14+
{
15+
/// <summary>
16+
/// The identifier of the class. It is not an argument but only something to
17+
/// separate different OptionalArgs.
18+
/// </summary>
19+
string Identifier { get; }
20+
}
21+
}

src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs renamed to src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55

6-
namespace Tensorflow.Keras.Saving
6+
namespace Tensorflow.Common.Types
77
{
88
public class TensorShapeConfig
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
using Newtonsoft.Json;
22
using System.Collections.Generic;
3+
using Tensorflow.Keras.Layers.Rnn;
34

45
namespace Tensorflow.Keras.ArgsDefinition.Rnn
56
{
7+
// TODO(Rinne): add regularizers.
68
public class RNNArgs : AutoSerializeLayerArgs
79
{
8-
public interface IRnnArgCell : ILayer
9-
{
10-
object state_size { get; }
11-
}
1210
[JsonProperty("cell")]
1311
// TODO: the cell should be serialized with `serialize_keras_object`.
14-
public IRnnArgCell Cell { get; set; } = null;
12+
public IRnnCell Cell { get; set; } = null;
1513
[JsonProperty("return_sequences")]
1614
public bool ReturnSequences { get; set; } = false;
1715
[JsonProperty("return_state")]
@@ -34,6 +32,9 @@ public interface IRnnArgCell : ILayer
3432
public IInitializer KernelInitializer { get; set; }
3533
public IInitializer RecurrentInitializer { get; set; }
3634
public IInitializer BiasInitializer { get; set; }
35+
public float Dropout { get; set; } = .0f;
36+
public bool ZeroOutputForMask { get; set; } = false;
37+
public float RecurrentDropout { get; set; } = .0f;
3738

3839
// kernel_regularizer=None,
3940
// recurrent_regularizer=None,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
7+
{
8+
public class RnnOptionalArgs: IOptionalArgs
9+
{
10+
public string Identifier => "Rnn";
11+
public Tensor Mask { get; set; } = null;
12+
public Tensors Constants { get; set; } = null;
13+
}
14+
}

0 commit comments

Comments
 (0)