Commit 1f3d0fe

upsample opt, execute pixel shuffle in torch
1 parent 8263b0e commit 1f3d0fe
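
For reference, the commit delegates the shuffle to torch.nn.functional.pixel_shuffle, which rearranges an (N, C·r², H, W) tensor into (N, C, H·r, W·r). A minimal shape-only sketch (the sizes and the upscale factor r = 2 below are illustrative, not taken from the SSR model):

import torch

r = 2                                  # illustrative upscale factor
x = torch.randn(1, 16 * r * r, 8, 8)   # NCHW input with 16 * r^2 channels
y = torch.nn.functional.pixel_shuffle(x, r)
print(y.shape)                         # torch.Size([1, 16, 16, 16]): channel blocks moved into space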

File tree

1 file changed: +36 −19 lines changed

models/experimental/SSR/tt/upsample.py

Lines changed: 36 additions & 19 deletions
@@ -1,6 +1,7 @@
 import ttnn
 import math
 from models.common.lightweightmodule import LightweightModule
+import torch
 
 
 class TTUpsample(LightweightModule):
@@ -75,12 +76,33 @@ def pixel_shuffle(self, x, upscale_factor):
 
         return x
 
+    def pixel_shuffle_torch(self, x, upscale_factor):
+        """Implement PixelShuffle operation using PyTorch for better performance"""
+        # PyTorch pixel_shuffle expects NCHW format, but our tensor is NHWC
+        # Convert from NHWC to NCHW
+        torch_tensor = x.permute(0, 3, 1, 2)
+
+        # Apply PyTorch pixel shuffle
+        torch_output = torch.nn.functional.pixel_shuffle(torch_tensor, upscale_factor)
+
+        # Convert back from NCHW to NHWC
+        torch_output = torch_output.permute(0, 2, 3, 1)
+
+        # Convert back to TTNN tensor
+        ttnn_output = ttnn.from_torch(
+            torch_output,
+            device=self.device,
+            dtype=ttnn.bfloat16,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            memory_config=self.memory_config,
+        )
+
+        return ttnn_output
+
     def forward(self, x, parameters):
         current = x
-        current_channels = self.num_feat  # Start with 4 channels
-        slice_config = ttnn.Conv2dSliceConfig(
-            slice_type=ttnn.Conv2dSliceHeight, num_slices=4  # Adjust based on memory constraints
-        )
+        current_channels = self.num_feat
+        slice_config = ttnn.Conv2dSliceConfig(slice_type=ttnn.Conv2dSliceHeight, num_slices=4)
         for i in range(self.num_ops):
             # Calculate output channels for this specific convolution
             out_channels = current_channels * (self.scale_factor * self.scale_factor)
@@ -91,8 +113,8 @@ def forward(self, x, parameters):
                 input_tensor=current,
                 weight_tensor=parameters[f"conv_{i}"]["weight"],
                 bias_tensor=parameters[f"conv_{i}"]["bias"] if parameters[f"conv_{i}"]["bias"] else None,
-                in_channels=current_channels,  # Use dynamic channel count
-                out_channels=out_channels,  # Use calculated output channels
+                in_channels=current_channels,
+                out_channels=out_channels,
                 device=self.device,
                 kernel_size=(3, 3),
                 stride=(1, 1),
@@ -103,25 +125,20 @@ def forward(self, x, parameters):
                 conv_config=self.conv_config,
                 compute_config=self.compute_config,
                 dtype=ttnn.bfloat16,
-                return_output_dim=False,  # Only return the output tensor for simplest call
+                return_output_dim=False,
                 return_weights_and_bias=False,
                 slice_config=slice_config,
             )
 
+            current = ttnn.to_torch(current)
             # reshape B,1,H*W, C to B, H, W, C
-            current = ttnn.reshape(
-                current,
-                (
-                    batch_size,
-                    current.shape[2] // (height * batch_size),
-                    current.shape[2] // (height * batch_size),
-                    out_channels,
-                ),
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
+            current = current.reshape(
+                batch_size,
+                current.shape[2] // (height * batch_size),
+                current.shape[2] // (height * batch_size),
+                out_channels,
             )
-            # Apply pixel shuffle
-            current = self.pixel_shuffle(current, self.scale_factor)
-
+            current = self.pixel_shuffle_torch(current, self.scale_factor)
             # After pixel shuffle, channels return to original count
             current_channels = self.num_feat
 

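As a host-side sanity check (not part of the commit), the NHWC → NCHW → pixel_shuffle → NHWC round trip used by pixel_shuffle_torch can be compared against a channels-last pixel shuffle written with reshape/permute only. A minimal sketch, assuming plain torch tensors and a hypothetical helper name pixel_shuffle_nhwc:

import torch

def pixel_shuffle_nhwc(x, r):
    # Channels-last pixel shuffle: split the channel dim into (C, r, r) and
    # interleave the r-blocks into the spatial dims.
    n, h, w, k = x.shape
    c = k // (r * r)
    x = x.view(n, h, w, c, r, r)
    x = x.permute(0, 1, 4, 2, 5, 3)      # (N, H, r, W, r, C)
    return x.reshape(n, h * r, w * r, c)

r = 2
x = torch.randn(1, 8, 8, 4 * r * r)      # NHWC, 4 output channels after the shuffle
# Path taken in the commit: permute to NCHW, shuffle, permute back to NHWC.
ref = torch.nn.functional.pixel_shuffle(x.permute(0, 3, 1, 2).contiguous(), r).permute(0, 2, 3, 1)
assert torch.equal(pixel_shuffle_nhwc(x, r), ref)   # identical result, shape (1, 16, 16, 4)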