@@ -9,44 +9,33 @@ public static partial class torchvision
9
9
{
10
10
internal class Normalize : ITransform , IDisposable
11
11
{
12
- internal Normalize ( double [ ] means , double [ ] stdevs , ScalarType dtype = ScalarType . Float32 , torch . Device ? device = null )
12
+ internal Normalize ( double [ ] means , double [ ] stdevs , bool inplace = false )
13
13
{
14
14
if ( means is null ) throw new ArgumentNullException ( nameof ( means ) ) ;
15
15
if ( stdevs is null ) throw new ArgumentNullException ( nameof ( stdevs ) ) ;
16
16
if ( means . Length != stdevs . Length )
17
17
throw new ArgumentException ( $ "{ nameof ( means ) } and { nameof ( stdevs ) } must be the same length in call to Normalize") ;
18
18
if ( means . Length != 1 && means . Length != 3 )
19
19
throw new ArgumentException ( $ "Since they correspond to the number of channels in an image, { nameof ( means ) } and { nameof ( stdevs ) } must both be either 1 or 3 long") ;
20
+ this . means = means ;
21
+ this . stdevs = stdevs ;
22
+ this . inplace = inplace ;
20
23
21
- this . means = means . ToTensor ( new long [ ] { 1 , means . Length , 1 , 1 } ) ; // Assumes NxCxHxW
22
- this . stdevs = stdevs . ToTensor ( new long [ ] { 1 , stdevs . Length , 1 , 1 } ) ; // Assumes NxCxHxW
23
-
24
- if ( dtype != ScalarType . Float64 ) {
25
- this . means = this . means . to_type ( dtype ) ;
26
- this . stdevs = this . stdevs . to_type ( dtype ) ;
27
- }
28
-
29
- if ( device != null && device . type != DeviceType . CPU ) {
30
- this . means = this . means . to ( device ) ;
31
- this . stdevs = this . stdevs . to ( device ) ;
32
- }
33
24
}
34
25
35
26
public Tensor call ( Tensor input )
36
27
{
37
- if ( means . size ( 1 ) != input . size ( 1 ) ) throw new ArgumentException ( "The number of channels is not equal to the number of means and standard deviations" ) ;
38
- return ( input - means ) / stdevs ;
28
+ return transforms . functional . normalize ( input , means , stdevs , inplace ) ;
39
29
}
40
30
41
- private Tensor means ;
42
- private Tensor stdevs ;
31
+ private readonly double [ ] means ;
32
+ private readonly double [ ] stdevs ;
33
+ private readonly bool inplace ;
43
34
bool disposedValue ;
44
35
45
36
protected virtual void Dispose ( bool disposing )
46
37
{
47
38
if ( ! disposedValue ) {
48
- means ? . Dispose ( ) ;
49
- stdevs ? . Dispose ( ) ;
50
39
disposedValue = true ;
51
40
}
52
41
}
@@ -72,12 +61,11 @@ public static partial class transforms
72
61
/// </summary>
73
62
/// <param name="means">Sequence of means for each channel.</param>
74
63
/// <param name="stdevs">Sequence of standard deviations for each channel.</param>
75
- /// <param name="dtype">Bool to make this operation inplace.</param>
76
- /// <param name="device">The device to place the output tensor on.</param>
64
+ /// <param name="inplace">Bool to make this operation inplace.</param>
77
65
/// <returns></returns>
78
- static public ITransform Normalize ( double [ ] means , double [ ] stdevs , ScalarType dtype = ScalarType . Float32 , torch . Device ? device = null )
66
+ static public ITransform Normalize ( double [ ] means , double [ ] stdevs , bool inplace = false )
79
67
{
80
- return new Normalize ( means , stdevs , dtype , device ) ;
68
+ return new Normalize ( means , stdevs , inplace ) ;
81
69
}
82
70
}
83
71
}
0 commit comments