@@ -77,10 +77,71 @@ void GatedDeltaNet::validate_and_infer_types() {
7777 input_check (this , 4 , " gate" , {3 }, {});
7878 input_check (this , 5 , " beta" , {3 }, {});
7979
80- // value head_size may be not same with key
80+ // batch, seq_len, head_num, head_size
81+ const auto & query_ps = get_input_partial_shape (0 );
82+ const auto & key_ps = get_input_partial_shape (1 );
83+ const auto & value_ps = get_input_partial_shape (2 );
84+ const auto & state_ps = get_input_partial_shape (3 );
85+ const auto & gate_ps = get_input_partial_shape (4 );
86+ const auto & beta_ps = get_input_partial_shape (5 );
87+
88+ const auto q_head_num = query_ps[2 ];
89+ const auto k_head_num = key_ps[2 ];
90+ const auto v_head_num = value_ps[2 ];
91+
92+ const auto k_head_size = key_ps[3 ];
93+ const auto v_head_size = value_ps[3 ];
94+
95+ NODE_VALIDATION_CHECK (this , q_head_num.is_static () && k_head_num.is_static () && q_head_num.get_length () == k_head_num.get_length (),
96+ " The number of heads in query and key should be the same, but got " ,
97+ q_head_num,
98+ " and " ,
99+ k_head_num,
100+ " ." );
101+
102+ NODE_VALIDATION_CHECK (this , k_head_size.is_static () && v_head_size.is_static () && k_head_size.get_length () == v_head_size.get_length (),
103+ " The head size in key and value should be the same, but got " ,
104+ k_head_size,
105+ " and " ,
106+ v_head_size,
107+ " ." );
108+
109+ const auto gate_head_num = gate_ps[2 ];
110+ const auto beta_head_num = beta_ps[2 ];
111+
112+ NODE_VALIDATION_CHECK (this , gate_head_num.is_static () && beta_head_num.is_static () && gate_head_num.get_length () == beta_head_num.get_length (),
113+ " The number of heads in gate and beta should be the same, but got " ,
114+ gate_head_num,
115+ " and " ,
116+ beta_head_num,
117+ " ." );
118+
119+ // [batch, v_head_nums, v_head_size, k_head_size]
120+ const auto state_head_num = state_ps[1 ];
121+ const auto state_hidden_size_0 = state_ps[2 ];
122+ const auto state_hidden_size_1 = state_ps[3 ];
123+ NODE_VALIDATION_CHECK (this , state_head_num.is_static () && state_head_num.get_length () == v_head_num.get_length (),
124+ " The number of heads in recurrent_state and value should be the same, but got " ,
125+ state_head_num,
126+ " and " ,
127+ v_head_num,
128+ " ." );
129+ NODE_VALIDATION_CHECK (this , state_hidden_size_0.is_static () && state_hidden_size_0.get_length () == v_head_size.get_length (),
130+ " The [-2] dim in shape of recurrent_state and value should be the same, but got " ,
131+ state_hidden_size_0,
132+ " and " ,
133+ v_head_size,
134+ " ." );
135+ NODE_VALIDATION_CHECK (this , state_hidden_size_1.is_static () && state_hidden_size_1.get_length () == k_head_size.get_length (),
136+ " The [-1] dim in shape of recurrent_state and key should be the same, but got " ,
137+ state_hidden_size_1,
138+ " and " ,
139+ k_head_size,
140+ " ." );
141+ // output has the same shape and type as input value, output state has the same shape and type as input recurrent_state
81142 auto out_ps = get_input_partial_shape (2 );
82143 const auto & h_ps = get_input_partial_shape (3 );
83- set_output_type (0 , get_input_element_type (0 ), out_ps);
144+ set_output_type (0 , get_input_element_type (2 ), out_ps);
84145 set_output_type (1 , get_input_element_type (3 ), h_ps);
85146}
86147
0 commit comments