File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -209,6 +209,10 @@ struct ReduceConfig {
209209 int values_per_thread () const {
210210 return div_up (num_inputs, step_input);
211211 }
212+
213+ int mock_values_per_thread (int parallelism) {
214+ return div_up (num_inputs, step_input * parallelism);
215+ }
212216};
213217
214218std::ostream& operator <<(std::ostream& out, const ReduceConfig& config);
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
11661170 else if (config.ctas_per_output < 16 )
11671171 config.ctas_per_output = 1 ;
11681172 bool is_channel_last = iter.tensor_base (1 ).is_contiguous (at::MemoryFormat::ChannelsLast);
1169- if (iter.ndim () == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
1173+ if (iter.ndim () == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
11701174 config.ctas_per_output = 4 ;
1175+ int vpt = config.values_per_thread ();
1176+ // Capping the number of values per thread to 2048 for now
1177+ // based on known use cases.
1178+ while (vpt >= 2048 ) {
1179+ config.ctas_per_output *= 2 ;
1180+ // Computes the new values per thread without side effects
1181+ vpt = config.mock_values_per_thread (config.ctas_per_output );
1182+ }
1183+ }
11711184#endif
11721185 if (config.ctas_per_output > 1 ) {
11731186 config.input_mult [2 ] = config.split_input (config.ctas_per_output );
You can’t perform that action at this time.
0 commit comments