1- import torch
2- import fast
3- import numpy as np
4- import pytest
5-
6-
7- def test_torch_tensor_fast_image_2d ():
8- types_to_test = [
9- (torch .uint8 , fast .TYPE_UINT8 , 255 ),
10- (torch .float32 , fast .TYPE_FLOAT , 1 ),
11- (torch .uint16 , fast .TYPE_UINT16 , 128 ),
12- (torch .int16 , fast .TYPE_INT16 , 128 )
13- ]
14- width = 64
15- height = 37
16-
17- for torch_type , fast_type , scale in types_to_test :
18- tensor = (torch .rand (height , width , dtype = torch .float32 )* scale ).to (torch_type )
19- image = fast .Image .createFromTensor (tensor )
20-
21- assert image .getWidth () == width
22- assert image .getHeight () == height
23- assert image .getNrOfChannels () == 1
24- assert image .getDataType () == fast_type
25- assert torch .equal (tensor , torch .tensor (np .asarray (image )).squeeze ())
26-
27- for channels in range (1 , 5 ):
28- tensor = (torch .rand (channels , height , width , dtype = torch .float32 )* scale ).to (torch_type )
1+ try :
2+ import torch
3+ import fast
4+ import numpy as np
5+ import pytest
6+
7+
8+ def test_torch_tensor_fast_image_2d ():
9+ types_to_test = [
10+ (torch .uint8 , fast .TYPE_UINT8 , 255 ),
11+ (torch .float32 , fast .TYPE_FLOAT , 1 ),
12+ (torch .uint16 , fast .TYPE_UINT16 , 128 ),
13+ (torch .int16 , fast .TYPE_INT16 , 128 )
14+ ]
15+ width = 64
16+ height = 37
17+
18+ for torch_type , fast_type , scale in types_to_test :
19+ tensor = (torch .rand (height , width , dtype = torch .float32 )* scale ).to (torch_type )
2920 image = fast .Image .createFromTensor (tensor )
3021
3122 assert image .getWidth () == width
3223 assert image .getHeight () == height
33- assert image .getNrOfChannels () == channels
24+ assert image .getNrOfChannels () == 1
3425 assert image .getDataType () == fast_type
35- assert torch .equal (tensor , torch .tensor (np .asarray (image )).permute ((2 , 0 , 1 )))
36-
37-
38- def test_torch_tensor_fast_image_3d ():
39- types_to_test = [
40- (torch .uint8 , fast .TYPE_UINT8 , 255 ),
41- (torch .float32 , fast .TYPE_FLOAT , 1 ),
42- (torch .uint16 , fast .TYPE_UINT16 , 128 ),
43- (torch .int16 , fast .TYPE_INT16 , 128 )
44- ]
45- width = 64
46- height = 37
47- depth = 42
48-
49- for torch_type , fast_type , scale in types_to_test :
50- tensor = (torch .rand (depth , height , width , dtype = torch .float32 )* scale ).to (torch_type )
51- image = fast .Image .createFromTensor (tensor )
52-
53- assert image .getWidth () == width
54- assert image .getHeight () == height
55- assert image .getDepth () == depth
56- assert image .getNrOfChannels () == 1
57- assert image .getDataType () == fast_type
58- assert torch .equal (tensor , torch .tensor (np .asarray (image )).squeeze ())
59-
60- for channels in range (1 , 5 ):
61- tensor = (torch .rand (channels , depth , height , width , dtype = torch .float32 )* scale ).to (torch_type )
26+ assert torch .equal (tensor , torch .tensor (np .asarray (image )).squeeze ())
27+
28+ for channels in range (1 , 5 ):
29+ tensor = (torch .rand (channels , height , width , dtype = torch .float32 )* scale ).to (torch_type )
30+ image = fast .Image .createFromTensor (tensor )
31+
32+ assert image .getWidth () == width
33+ assert image .getHeight () == height
34+ assert image .getNrOfChannels () == channels
35+ assert image .getDataType () == fast_type
36+ assert torch .equal (tensor , torch .tensor (np .asarray (image )).permute ((2 , 0 , 1 )))
37+
38+
39+ def test_torch_tensor_fast_image_3d ():
40+ types_to_test = [
41+ (torch .uint8 , fast .TYPE_UINT8 , 255 ),
42+ (torch .float32 , fast .TYPE_FLOAT , 1 ),
43+ (torch .uint16 , fast .TYPE_UINT16 , 128 ),
44+ (torch .int16 , fast .TYPE_INT16 , 128 )
45+ ]
46+ width = 64
47+ height = 37
48+ depth = 42
49+
50+ for torch_type , fast_type , scale in types_to_test :
51+ tensor = (torch .rand (depth , height , width , dtype = torch .float32 )* scale ).to (torch_type )
6252 image = fast .Image .createFromTensor (tensor )
6353
6454 assert image .getWidth () == width
6555 assert image .getHeight () == height
6656 assert image .getDepth () == depth
67- assert image .getNrOfChannels () == channels
57+ assert image .getNrOfChannels () == 1
6858 assert image .getDataType () == fast_type
69- assert torch .equal (tensor , torch .tensor (np .asarray (image )).permute ((3 , 0 , 1 , 2 )))
70-
71-
72- def test_torch_tensor_to_image_exceptions ():
73- # Have to use createFromTensor not createFromArray
74- tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float32 )
75- with pytest .raises (ValueError ):
76- fast .Image .createFromArray (tensor )
77-
78- # Only 1 dim
79- tensor = torch .rand ((7 ,), dtype = torch .float32 )
80- with pytest .raises (ValueError ):
81- fast .Image .createFromTensor (tensor )
82-
83- # More than 4 dims
84- tensor = torch .rand ((7 , 32 , 13 , 23 , 54 ), dtype = torch .float32 )
85- with pytest .raises (ValueError ):
86- fast .Image .createFromTensor (tensor )
87-
88- # Incorrect type
89- tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float64 )
90- with pytest .raises (TypeError ):
91- fast .Image .createFromTensor (tensor )
92- with pytest .raises (ValueError ):
93- fast .Image .createFromTensor ('' )
94-
95-
96- def test_torch_tensor_to_tensor_exceptions ():
97- # Incorrect type
98- tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float32 )
99- with pytest .raises (ValueError ):
100- fast .Tensor .createFromArray (tensor )
101- with pytest .raises (ValueError ):
102- fast .Tensor .createFromTensor ('' )
103-
104-
105- def test_torch_tensor_fast_tensor ():
106- types_to_test = [
107- (torch .uint8 , 255 ),
108- (torch .int8 , 127 ),
109- (torch .float32 , 1 ),
110- (torch .float64 , 1 ),
111- (torch .uint16 , 128 ),
112- (torch .int16 , 128 )
113- ]
114- for type , scale in types_to_test :
115- shape = (23 ,)
116- tensor = (torch .rand (shape )* scale ).to (type )
117- fast_tensor = fast .Tensor .createFromTensor (tensor )
59+ assert torch .equal (tensor , torch .tensor (np .asarray (image )).squeeze ())
60+
61+ for channels in range (1 , 5 ):
62+ tensor = (torch .rand (channels , depth , height , width , dtype = torch .float32 )* scale ).to (torch_type )
63+ image = fast .Image .createFromTensor (tensor )
64+
65+ assert image .getWidth () == width
66+ assert image .getHeight () == height
67+ assert image .getDepth () == depth
68+ assert image .getNrOfChannels () == channels
69+ assert image .getDataType () == fast_type
70+ assert torch .equal (tensor , torch .tensor (np .asarray (image )).permute ((3 , 0 , 1 , 2 )))
71+
72+
73+ def test_torch_tensor_to_image_exceptions ():
74+ # Have to use createFromTensor not createFromArray
75+ tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float32 )
76+ with pytest .raises (ValueError ):
77+ fast .Image .createFromArray (tensor )
78+
79+ # Only 1 dim
80+ tensor = torch .rand ((7 ,), dtype = torch .float32 )
81+ with pytest .raises (ValueError ):
82+ fast .Image .createFromTensor (tensor )
83+
84+ # More than 4 dims
85+ tensor = torch .rand ((7 , 32 , 13 , 23 , 54 ), dtype = torch .float32 )
86+ with pytest .raises (ValueError ):
87+ fast .Image .createFromTensor (tensor )
88+
89+ # Incorrect type
90+ tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float64 )
91+ with pytest .raises (TypeError ):
92+ fast .Image .createFromTensor (tensor )
93+ with pytest .raises (ValueError ):
94+ fast .Image .createFromTensor ('' )
95+
96+
97+ def test_torch_tensor_to_tensor_exceptions ():
98+ # Incorrect type
99+ tensor = torch .rand ((1 , 32 , 32 ), dtype = torch .float32 )
100+ with pytest .raises (ValueError ):
101+ fast .Tensor .createFromArray (tensor )
102+ with pytest .raises (ValueError ):
103+ fast .Tensor .createFromTensor ('' )
104+
105+
106+ def test_torch_tensor_fast_tensor ():
107+ types_to_test = [
108+ (torch .uint8 , 255 ),
109+ (torch .int8 , 127 ),
110+ (torch .float32 , 1 ),
111+ (torch .float64 , 1 ),
112+ (torch .uint16 , 128 ),
113+ (torch .int16 , 128 )
114+ ]
115+ for type , scale in types_to_test :
116+ shape = (23 ,)
117+ tensor = (torch .rand (shape )* scale ).to (type )
118+ fast_tensor = fast .Tensor .createFromTensor (tensor )
119+ assert fast_tensor .getShape ().getDimensions () == len (shape )
120+ assert fast_tensor .getShape ().getAll ()[0 ] == shape [0 ]
121+ assert torch .equal (tensor .to (torch .float32 ), torch .tensor (np .array (fast_tensor )))
122+
123+ shape = (23 ,1 ,23 ,67 )
124+ tensor = (torch .rand (shape )* scale ).to (type )
125+ fast_tensor = fast .Tensor .createFromTensor (tensor )
126+ assert fast_tensor .getShape ().getDimensions () == len (shape )
127+ for i in range (len (shape )):
128+ assert fast_tensor .getShape ().getAll ()[i ] == shape [i ]
129+ assert torch .equal (tensor .to (torch .float32 ), torch .tensor (np .array (fast_tensor )))
130+
131+
132+ def test_torch_tensor_fast_tensor_channels_last_conversion ():
133+ shape = (2 , 32 , 40 )
134+ tensor = torch .rand (shape , dtype = torch .float32 )
135+ fast_tensor = fast .Tensor .createFromTensor (tensor , convertToChannelsLast = True )
118136 assert fast_tensor .getShape ().getDimensions () == len (shape )
119- assert fast_tensor .getShape ().getAll ()[0 ] == shape [0 ]
120- assert torch .equal (tensor .to (torch .float32 ), torch .tensor (np .array (fast_tensor )))
137+ assert fast_tensor .getShape ().getAll ()[0 ] == shape [1 ]
138+ assert fast_tensor .getShape ().getAll ()[1 ] == shape [2 ]
139+ assert fast_tensor .getShape ().getAll ()[2 ] == shape [0 ]
140+ assert torch .equal (tensor .permute ((1 ,2 ,0 )), torch .tensor (np .array (fast_tensor )))
121141
122- shape = (23 ,1 ,23 ,67 )
123- tensor = (torch .rand (shape )* scale ).to (type )
124- fast_tensor = fast .Tensor .createFromTensor (tensor )
125- assert fast_tensor .getShape ().getDimensions () == len (shape )
126- for i in range (len (shape )):
127- assert fast_tensor .getShape ().getAll ()[i ] == shape [i ]
128- assert torch .equal (tensor .to (torch .float32 ), torch .tensor (np .array (fast_tensor )))
129-
130-
131- def test_torch_tensor_fast_tensor_channels_last_conversion ():
132- shape = (2 , 32 , 40 )
133- tensor = torch .rand (shape , dtype = torch .float32 )
134- fast_tensor = fast .Tensor .createFromTensor (tensor , convertToChannelsLast = True )
135- assert fast_tensor .getShape ().getDimensions () == len (shape )
136- assert fast_tensor .getShape ().getAll ()[0 ] == shape [1 ]
137- assert fast_tensor .getShape ().getAll ()[1 ] == shape [2 ]
138- assert fast_tensor .getShape ().getAll ()[2 ] == shape [0 ]
139- assert torch .equal (tensor .permute ((1 ,2 ,0 )), torch .tensor (np .array (fast_tensor )))
142+ except ImportError :
143+ print ('Torch not installed. Not able to run tests.' )
0 commit comments