11import ttnn
22import math
33from models .common .lightweightmodule import LightweightModule
4+ import torch
45
56
67class TTUpsample (LightweightModule ):
@@ -75,12 +76,33 @@ def pixel_shuffle(self, x, upscale_factor):
7576
7677 return x
7778
79+ def pixel_shuffle_torch (self , x , upscale_factor ):
80+ """Implement PixelShuffle operation using PyTorch for better performance"""
81+ # PyTorch pixel_shuffle expects NCHW format, but our tensor is NHWC
82+ # Convert from NHWC to NCHW
83+ torch_tensor = x .permute (0 , 3 , 1 , 2 )
84+
85+ # Apply PyTorch pixel shuffle
86+ torch_output = torch .nn .functional .pixel_shuffle (torch_tensor , upscale_factor )
87+
88+ # Convert back from NCHW to NHWC
89+ torch_output = torch_output .permute (0 , 2 , 3 , 1 )
90+
91+ # Convert back to TTNN tensor
92+ ttnn_output = ttnn .from_torch (
93+ torch_output ,
94+ device = self .device ,
95+ dtype = ttnn .bfloat16 ,
96+ layout = ttnn .ROW_MAJOR_LAYOUT ,
97+ memory_config = self .memory_config ,
98+ )
99+
100+ return ttnn_output
101+
78102 def forward (self , x , parameters ):
79103 current = x
80- current_channels = self .num_feat # Start with 4 channels
81- slice_config = ttnn .Conv2dSliceConfig (
82- slice_type = ttnn .Conv2dSliceHeight , num_slices = 4 # Adjust based on memory constraints
83- )
104+ current_channels = self .num_feat
105+ slice_config = ttnn .Conv2dSliceConfig (slice_type = ttnn .Conv2dSliceHeight , num_slices = 4 )
84106 for i in range (self .num_ops ):
85107 # Calculate output channels for this specific convolution
86108 out_channels = current_channels * (self .scale_factor * self .scale_factor )
@@ -91,8 +113,8 @@ def forward(self, x, parameters):
91113 input_tensor = current ,
92114 weight_tensor = parameters [f"conv_{ i } " ]["weight" ],
93115 bias_tensor = parameters [f"conv_{ i } " ]["bias" ] if parameters [f"conv_{ i } " ]["bias" ] else None ,
94- in_channels = current_channels , # Use dynamic channel count
95- out_channels = out_channels , # Use calculated output channels
116+ in_channels = current_channels ,
117+ out_channels = out_channels ,
96118 device = self .device ,
97119 kernel_size = (3 , 3 ),
98120 stride = (1 , 1 ),
@@ -103,25 +125,20 @@ def forward(self, x, parameters):
103125 conv_config = self .conv_config ,
104126 compute_config = self .compute_config ,
105127 dtype = ttnn .bfloat16 ,
106- return_output_dim = False , # Only return the output tensor for simplest call
128+ return_output_dim = False ,
107129 return_weights_and_bias = False ,
108130 slice_config = slice_config ,
109131 )
110132
133+ current = ttnn .to_torch (current )
111134 # reshape B,1,H*W, C to B, H, W, C
112- current = ttnn .reshape (
113- current ,
114- (
115- batch_size ,
116- current .shape [2 ] // (height * batch_size ),
117- current .shape [2 ] // (height * batch_size ),
118- out_channels ,
119- ),
120- memory_config = ttnn .DRAM_MEMORY_CONFIG ,
135+ current = current .reshape (
136+ batch_size ,
137+ current .shape [2 ] // (height * batch_size ),
138+ current .shape [2 ] // (height * batch_size ),
139+ out_channels ,
121140 )
122- # Apply pixel shuffle
123- current = self .pixel_shuffle (current , self .scale_factor )
124-
141+ current = self .pixel_shuffle_torch (current , self .scale_factor )
125142 # After pixel shuffle, channels return to original count
126143 current_channels = self .num_feat
127144
0 commit comments