1+ import numpy as np
2+ import TensorFrost as tf
3+ import matplotlib .pyplot as plt
4+ import os
# Absolute directory of this script; the input image path below is built from it.
current_folder = os.path.dirname(os.path.abspath(__file__))

# Initialize TensorFrost with the OpenGL backend.
tf.initialize(tf.opengl)
8+
def get_except_axis(a, axis):
    """Return the items of ``a`` as a tuple with the entry at ``axis`` dropped.

    ``axis`` follows ``list.pop`` semantics: negative values count from the
    end, and a position outside the sequence raises ``IndexError``.
    """
    remaining = [*a]
    remaining.pop(axis)  # discard the entry that belongs to the chosen axis
    return tuple(remaining)
13+
def get_with_axis(indices, axis_index, axis):
    """Return ``indices`` as a tuple with ``axis_index`` placed at ``axis``.

    ``axis`` is a position in the *result*, which has ``len(indices) + 1``
    entries, so that ``result[axis] == axis_index``. Negative values count
    from the end of the result, i.e. ``axis=-1`` appends the new entry.
    A non-negative ``axis`` past the end also appends (``list.insert``
    behaviour).

    Raises:
        IndexError: if a negative ``axis`` points before the start of the
            result.
    """
    new_indices = list(indices)
    if axis < 0:
        # list.insert(-1, x) would place x *before* the last element, one slot
        # too early; resolve the negative axis against the result length
        # instead so it mirrors how get_except_axis removed the entry.
        normalized = axis + len(new_indices) + 1
        if normalized < 0:
            raise IndexError(
                f"axis {axis} is out of range for a result of "
                f"{len(new_indices) + 1} indices"
            )
        axis = normalized
    new_indices.insert(axis, axis_index)
    return tuple(new_indices)
19+
def intlog2(x):
    """Return the integer base-2 logarithm of ``x`` (rounded down).

    The result is the number of one-bit right shifts needed to bring ``x``
    down to 1.

    Raises:
        ValueError: if ``x`` is smaller than 1.
    """
    if x < 1:
        raise ValueError("x must be a positive integer")
    exponent = 0
    while x > 1:
        # Halve x and count the step in one go.
        x, exponent = x >> 1, exponent + 1
    return exponent
29+
def inplace_fft(tensor, axis=-1, inverse=False):
    """Apply a radix-2 FFT along one axis of ``tensor``, overwriting it.

    The last dimension of ``tensor`` holds the real (index 0) and imaginary
    (index 1) part of each complex value. ``axis`` selects one of the
    remaining dimensions and may be negative.

    A single kernel is emitted over the remaining dimensions plus a trailing
    dimension of ``BK`` threads (the group size). Each thread loads its share
    of the line into a group buffer in bit-reversed order, runs ``log2(N)``
    butterfly stages separated by group barriers, and writes the line back.

    Args:
        tensor: TensorFrost tensor whose last dimension is the (re, im) pair.
        axis: Dimension to transform, relative to the shape without the last
            (complex) dimension.
        inverse: When true, the twiddle angle sign is flipped and the result
            is scaled by ``1 / N``.

    Raises:
        ValueError: if the size along ``axis`` is not a constant power of two.
    """
    shape = tensor.shape[:-1]  # shape without the last dimension (complex number)
    N = shape[axis].try_get_constant()
    print("N:", N)
    if N is None or N < 1 or N & (N - 1) != 0:
        raise ValueError("FFT only supports constant power of 2 sizes")
    if N == 1:
        # A length-1 transform is the identity (forward and inverse alike);
        # continuing would give a group size of 0 and divide by zero below.
        return
    if axis < 0:
        # Resolve the axis once: the index tuples built below have one entry
        # fewer than ``shape`` while the axis entry is missing, so a negative
        # position would be interpreted against the wrong length.
        axis += len(shape)

    BK = min(256, N // 2)  # Group size
    RADIX2 = intlog2(N)  # number of butterfly stages / bits in a line index
    M = N // BK  # complex values handled per thread
    # Only the inverse transform is normalized; hoisted out of the store loop.
    factor = 1.0 / float(N if inverse else 1)

    def expi(angle):
        # (cos, sin) pair, i.e. exp(i * angle) as a (re, im) tuple.
        return tf.cos(angle), tf.sin(angle)

    def cmul(a, b):
        # Complex product of two (re, im) tuples.
        return a[0] * b[0] - a[1] * b[1], a[0] * b[1] + a[1] * b[0]

    def radix2(temp, span, index, inverse):
        # Map the butterfly number ``index`` to the pair (k1, k2 = k1 + span):
        # the low bits give the offset inside a half-group of ``span`` values,
        # the remaining bits are doubled to skip over the partner half.
        group_half_mask = span - 1
        group_offset = index & group_half_mask
        group_index = index - group_offset
        k1 = (group_index << 1) + group_offset
        k2 = k1 + span

        d = 1.0 if inverse else -1.0
        angle = 2 * np.pi * d * tf.float(group_offset) / tf.float(span << 1)

        # radix2 butterfly
        v1 = temp[2 * k1], temp[2 * k1 + 1]
        v2 = cmul(expi(angle), (temp[2 * k2], temp[2 * k2 + 1]))
        temp[2 * k1] = v1[0] + v2[0]
        temp[2 * k1 + 1] = v1[1] + v2[1]
        temp[2 * k2] = v1[0] - v2[0]
        temp[2 * k2 + 1] = v1[1] - v2[1]

    # workgroup mapped to our axis
    new_shape = get_except_axis(shape, axis) + (BK,)
    with tf.kernel(list(new_shape), group_size=[BK]) as indices:
        # Interleaved (re, im) storage for one full line of N complex values.
        temp = tf.group_buffer(N * 2, tf.float32)
        if not isinstance(indices, tuple):
            indices = (indices,)

        # NOTE(review): presumably the thread's index inside its group —
        # confirm against the TensorFrost API.
        tx = indices[0].block_thread_index(0)
        indices = get_except_axis(indices, -1)  # skip the workgroup axis

        # Load the line, scattering each value to its bit-reversed slot: the
        # 32-bit reversal is shifted down so only RADIX2 bits remain.
        for i in range(M):
            rowIndex = i * BK + tx
            idx = tf.int(tf.reversebits(tf.uint(rowIndex)) >> (32 - RADIX2))
            tensor_index = get_with_axis(indices, rowIndex, axis)
            temp[2 * idx] = tensor[tensor_index + (0,)]
            temp[2 * idx + 1] = tensor[tensor_index + (1,)]

        tf.group_barrier()

        # One stage per doubling of ``span``; each thread performs M // 2 of
        # the N // 2 butterflies, then the group synchronizes before the next
        # stage reads the results.
        span = 1
        while span < N:
            for j in range(M // 2):
                rowIndex = j * BK + tx
                radix2(temp, span, rowIndex, inverse)
            tf.group_barrier()
            span *= 2

        # Write the transformed line back in natural order.
        for i in range(M):
            rowIndex = i * BK + tx
            tensor_index = get_with_axis(indices, rowIndex, axis)
            tensor[tensor_index + (0,)] = temp[2 * rowIndex] * factor
            tensor[tensor_index + (1,)] = temp[2 * rowIndex + 1] * factor
97+
# Side length of the square complex image the program is built for.
target_res = 1024


def fft():
    """Build the TensorFrost program: a forward FFT along axis 0, then axis 1.

    The program takes one float32 input of shape
    ``[target_res, target_res, 2]`` (last dimension = real/imaginary part),
    transforms a ``tf.copy`` of it in place and returns that copy.
    """
    source = tf.input([target_res, target_res, 2], tf.float32)
    spectrum = tf.copy(source)
    for image_axis in (0, 1):
        inplace_fft(spectrum, axis=image_axis, inverse=False)
    return spectrum


# From here on ``fft`` names the compiled program rather than the builder.
fft = tf.compile(fft)
108+
# Print an entry of every kernel TensorFrost generated so far.
# NOTE(review): k[0][2] is presumably the generated kernel source — confirm
# against the structure returned by tf.get_all_generated_kernels().
all_kernels = tf.get_all_generated_kernels()
print("Generated kernels:")
for k in all_kernels:
    print(k[0][2])

# Load the test image that sits next to this script.
input_img = np.array(plt.imread(os.path.join(current_folder, "test.png")), dtype=np.float32)
# np.pad rejects negative pad widths, so an image larger than target_res is
# cropped to the top-left target_res x target_res region first; smaller images
# are unaffected by the slice.
input_img = input_img[:target_res, :target_res]
pad_rows = target_res - input_img.shape[0]
pad_cols = target_res - input_img.shape[1]
# Zero-pad at the bottom/right up to target_res x target_res; channels untouched.
image_resampled = np.pad(input_img, ((0, pad_rows), (0, pad_cols), (0, 0)), 'constant')

plt.imshow(image_resampled)
plt.show()
print(image_resampled.shape)

# Use the first channel as the real part; the imaginary plane stays at the
# zeros it was allocated with.
r_channel = image_resampled[..., 0]
complex_image = np.zeros((target_res, target_res, 2), dtype=np.float32)
complex_image[..., 0] = r_channel

# Run the compiled program and read the result back as a NumPy array.
image_tf = tf.tensor(complex_image)
transformed = fft(image_tf)
transformed = transformed.numpy

# plot the log-magnitude of the transformed image
plt.imshow(np.log(np.abs(transformed[..., 0] + 1j * transformed[..., 1])))
plt.colorbar()
plt.title("FFT")
plt.show()
0 commit comments