Skip to content

Commit 009c4dc

Browse files
committed
update python file for amd bug (and params for exp1)
1 parent e99c858 commit 009c4dc

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

diag/conv1d.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010

1111
# Parameters
1212

13-
channel_in = 3
14-
channel_out = channel_in
13+
batch_size=16384
1514
length = 512
15+
k = 3
16+
channel_in = 1
17+
channel_out = channel_in
1618

17-
n0 = 16384
19+
n0 = batch_size
1820
n1 = length*channel_out
1921
n2 = 1
20-
k = 2
2122

22-
batch_size=n0*n2
2323
# i0, i1, i2 = torch.meshgrid(
2424
# torch.arange(batch_size, device=device, dtype=real_t),
2525
# torch.arange(channel_in, device=device, dtype=real_t),
@@ -61,6 +61,7 @@ def sum_and_normalize(data):
6161
output = conv1d_scripted(data) # Use pre-compiled model
6262
torch.cuda.synchronize()
6363
end.record()
64+
torch.cuda.synchronize()
6465
elapsed_time = start.elapsed_time(end)/1000
6566

6667
# Compute final sum and normalization

0 commit comments

Comments
 (0)