Skip to content

Commit 500f89a

Browse files
committed
update version
1 parent adff979 commit 500f89a

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

Python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
88

99
[project]
1010
name = "TensorFrost"
11-
version = "0.7.3dev1"
11+
version = "0.7.3"
1212
description = "A static optimizing tensor compiler with a Python frontend…"
1313
authors = [{name = "Mykhailo Moroz", email = "michael08840884@gmail.com"}]
1414
requires-python = ">=3.7"

examples/Rendering/fft3d.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import numpy as np
2+
import TensorFrost as tf
3+
import matplotlib.pyplot as plt
4+
import os
5+
current_folder = os.path.dirname(os.path.abspath(__file__))
6+
7+
tf.initialize(tf.opengl)
8+
9+
def get_except_axis(a, axis):
10+
indices = list(a)
11+
indices.pop(axis) # Remove the axis dimension from the indices
12+
return tuple(indices)
13+
14+
def get_with_axis(indices, axis_index, axis):
15+
# Create a new tuple of indices with the axis index inserted at the specified axis
16+
new_indices = list(indices)
17+
new_indices.insert(axis, axis_index)
18+
return tuple(new_indices)
19+
20+
def intlog2(x):
21+
# Calculate the integer logarithm base 2 of x
22+
if x < 1:
23+
raise ValueError("x must be a positive integer")
24+
log = 0
25+
while x > 1:
26+
x >>= 1
27+
log += 1
28+
return log
29+
30+
def inplace_fft(tensor, axis = -1, inverse = False):
31+
shape = tensor.shape[:-1] #shape without the last dimension (complex number)
32+
N = shape[axis].try_get_constant()
33+
print("N:", N)
34+
if N == None or N & (N - 1) != 0:
35+
raise ValueError("FFT only supports constant power of 2 sizes")
36+
BK = min(256, N // 2) #Group size
37+
RADIX2 = intlog2(N)
38+
39+
def expi(angle):
40+
return tf.cos(angle), tf.sin(angle)
41+
42+
def cmul(a, b):
43+
return a[0] * b[0] - a[1] * b[1], a[0] * b[1] + a[1] * b[0]
44+
45+
def radix2(temp, span, index, inverse):
46+
group_half_mask = span - 1
47+
group_offset = index & group_half_mask
48+
group_index = index - group_offset
49+
k1 = (group_index << 1) + group_offset
50+
k2 = k1 + span
51+
52+
d = 1.0 if inverse else -1.0
53+
angle = 2 * np.pi * d * tf.float(group_offset) / tf.float(span << 1)
54+
55+
#radix2 butterfly
56+
v1 = temp[2*k1], temp[2*k1 + 1]
57+
v2 = cmul(expi(angle), (temp[2*k2], temp[2*k2 + 1]))
58+
temp[2*k1] = v1[0] + v2[0]
59+
temp[2*k1 + 1] = v1[1] + v2[1]
60+
temp[2*k2] = v1[0] - v2[0]
61+
temp[2*k2 + 1] = v1[1] - v2[1]
62+
63+
#workgroup mapped to our axis
64+
new_shape = get_except_axis(shape, axis) + (BK,)
65+
with tf.kernel(list(new_shape), group_size=[BK]) as indices:
66+
temp = tf.group_buffer(N*2, tf.float32)
67+
if not isinstance(indices, tuple):
68+
indices = (indices,)
69+
70+
tx = indices[0].block_thread_index(0)
71+
indices = get_except_axis(indices, -1) #skip the workgroup axis
72+
M = N // BK
73+
74+
for i in range(M):
75+
rowIndex = i * BK + tx
76+
idx = tf.int(tf.reversebits(tf.uint(rowIndex)) >> (32 - RADIX2))
77+
tensor_index = get_with_axis(indices, rowIndex, axis)
78+
temp[2*idx] = tensor[tensor_index + (0,)]
79+
temp[2*idx + 1] = tensor[tensor_index + (1,)]
80+
81+
tf.group_barrier()
82+
83+
span = 1
84+
while span < N:
85+
for j in range(M // 2):
86+
rowIndex = j * BK + tx
87+
radix2(temp, span, rowIndex, inverse)
88+
tf.group_barrier()
89+
span *= 2
90+
91+
for i in range(M):
92+
rowIndex = i * BK + tx
93+
factor = 1.0 / float(N if inverse else 1)
94+
tensor_index = get_with_axis(indices, rowIndex, axis)
95+
tensor[tensor_index + (0,)] = temp[2*rowIndex] * factor
96+
tensor[tensor_index + (1,)] = temp[2*rowIndex + 1] * factor
97+
98+
target_res = 1024
99+
100+
def fft():
101+
A = tf.input([target_res, target_res, 2], tf.float32)
102+
B = tf.copy(A)
103+
inplace_fft(B, axis=0, inverse=False)
104+
inplace_fft(B, axis=1, inverse=False)
105+
return B
106+
107+
fft = tf.compile(fft)
108+
109+
all_kernels = tf.get_all_generated_kernels()
110+
print("Generated kernels:")
111+
for k in all_kernels:
112+
print(k[0][2])
113+
114+
input_img = np.array(plt.imread(current_folder+"/test.png"), dtype=np.float32)
115+
image_resampled = np.pad(input_img, ((0, target_res - input_img.shape[0]), (0, target_res - input_img.shape[1]), (0, 0)), 'constant')
116+
117+
plt.imshow(image_resampled)
118+
plt.show()
119+
print(image_resampled.shape)
120+
121+
r_channel = image_resampled[..., 0]
122+
complex_image = np.zeros((target_res, target_res, 2), dtype=np.float32)
123+
complex_image[..., 0] = r_channel
124+
complex_image[..., 1] = np.zeros((target_res, target_res), dtype=np.float32)
125+
126+
image_tf = tf.tensor(complex_image)
127+
transformed = fft(image_tf)
128+
transformed = transformed.numpy
129+
130+
#plot the magnitude of the transformed image
131+
plt.imshow(np.log(np.abs(transformed[..., 0] + 1j * transformed[..., 1])))
132+
plt.colorbar()
133+
plt.title("FFT")
134+
plt.show()

0 commit comments

Comments
 (0)