@@ -233,34 +233,28 @@ tl::expected<TensorShape, std::string>
233233 };
234234}
235235
236- tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic< TensorShape> >, std::string>
236+ tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
237237 get_weight_shapes (MultiHeadAttentionAttrs const &attrs,
238238 TensorShape const &input_q,
239239 TensorShape const &input_k,
240240 TensorShape const &input_v) {
241241
242- std::unordered_map<TensorSlotName, SingularOrVariadic< TensorShape> > weight_shapes = {
242+ std::unordered_map<TensorSlotName, TensorShape> weight_shapes = {
243243 {
244244 TensorSlotName::WEIGHT,
245- SingularOrVariadic<TensorShape>{
246- PROPAGATE_ERR (get_weights_shape (attrs, input_q, input_k, input_v)),
247- },
245+ PROPAGATE_ERR (get_weights_shape (attrs, input_q, input_k, input_v)),
248246 },
249247 };
250248
251249 if (attrs.bias ) {
252250 weight_shapes.insert ({
253251 TensorSlotName::INPUT_BIAS,
254- SingularOrVariadic<TensorShape>{
255- PROPAGATE_ERR (get_input_bias_shape (attrs, input_q, input_k, input_v)),
256- },
252+ PROPAGATE_ERR (get_input_bias_shape (attrs, input_q, input_k, input_v)),
257253 });
258254
259255 weight_shapes.insert ({
260256 TensorSlotName::OUTPUT_BIAS,
261- SingularOrVariadic<TensorShape>{
262- PROPAGATE_ERR (get_output_bias_shape (attrs, input_q, input_k, input_v)),
263- },
257+ PROPAGATE_ERR (get_output_bias_shape (attrs, input_q, input_k, input_v)),
264258 });
265259 }
266260
@@ -422,34 +416,28 @@ positive_int get_oSize(TensorShape const &) {
422416 NOT_IMPLEMENTED ();
423417}
424418
425- tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic< ParallelTensorShape> >, std::string>
419+ tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
426420 get_weight_shapes (MultiHeadAttentionAttrs const &attrs,
427421 ParallelTensorShape const &input_q,
428422 ParallelTensorShape const &input_k,
429423 ParallelTensorShape const &input_v) {
430424
431- std::unordered_map<TensorSlotName, SingularOrVariadic< ParallelTensorShape> > weight_shapes = {
425+ std::unordered_map<TensorSlotName, ParallelTensorShape> weight_shapes = {
432426 {
433427 TensorSlotName::WEIGHT,
434- SingularOrVariadic<ParallelTensorShape>{
435- PROPAGATE_ERR (get_weights_shape (attrs, input_q, input_k, input_v)),
436- },
428+ PROPAGATE_ERR (get_weights_shape (attrs, input_q, input_k, input_v)),
437429 },
438430 };
439431
440432 if (attrs.bias ) {
441433 weight_shapes.insert ({
442434 TensorSlotName::INPUT_BIAS,
443- SingularOrVariadic<ParallelTensorShape>{
444- PROPAGATE_ERR (get_input_bias_shape (attrs, input_q, input_k, input_v)),
445- },
435+ PROPAGATE_ERR (get_input_bias_shape (attrs, input_q, input_k, input_v)),
446436 });
447437
448438 weight_shapes.insert ({
449439 TensorSlotName::OUTPUT_BIAS,
450- SingularOrVariadic<ParallelTensorShape>{
451- PROPAGATE_ERR (get_output_bias_shape (attrs, input_q, input_k, input_v)),
452- },
440+ PROPAGATE_ERR (get_output_bias_shape (attrs, input_q, input_k, input_v)),
453441 });
454442 }
455443
0 commit comments