@@ -353,6 +353,7 @@ def _filter(self,
353353 self .filter_precision ,
354354 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
355355 trainable = trainable )
356+ hidden = tf .reshape (activation_fn (tf .matmul (xyz_scatter , w ) + b ), [- 1 , outputs_size [ii ]])
356357 if self .filter_resnet_dt :
357358 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
358359 [1 , outputs_size [ii ]],
@@ -361,16 +362,16 @@ def _filter(self,
361362 trainable = trainable )
362363 if outputs_size [ii ] == outputs_size [ii - 1 ]:
363364 if self .filter_resnet_dt :
364- xyz_scatter += activation_fn ( tf . matmul ( xyz_scatter , w ) + b ) * idt
365+ xyz_scatter += hidden * idt
365366 else :
366- xyz_scatter += activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
367+ xyz_scatter += hidden
367368 elif outputs_size [ii ] == outputs_size [ii - 1 ] * 2 :
368369 if self .filter_resnet_dt :
369- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn ( tf . matmul ( xyz_scatter , w ) + b ) * idt
370+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + hidden * idt
370371 else :
371- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
372+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + hidden
372373 else :
373- xyz_scatter = activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
374+ xyz_scatter = hidden
374375 else :
375376 w = tf .zeros ((outputs_size [0 ], outputs_size [- 1 ]), dtype = global_tf_float_precision )
376377 xyz_scatter = tf .matmul (xyz_scatter , w )
@@ -440,6 +441,7 @@ def _filter_type_ext(self,
440441 self .filter_precision ,
441442 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
442443 trainable = trainable )
444+ hidden = tf .reshape (activation_fn (tf .matmul (xyz_scatter , w ) + b ), [- 1 , outputs_size [ii ]])
443445 if self .filter_resnet_dt :
444446 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
445447 [1 , outputs_size [ii ]],
@@ -448,16 +450,16 @@ def _filter_type_ext(self,
448450 trainable = trainable )
449451 if outputs_size [ii ] == outputs_size [ii - 1 ]:
450452 if self .filter_resnet_dt :
451- xyz_scatter += activation_fn ( tf . matmul ( xyz_scatter , w ) + b ) * idt
453+ xyz_scatter += hidden * idt
452454 else :
453- xyz_scatter += activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
455+ xyz_scatter += hidden
454456 elif outputs_size [ii ] == outputs_size [ii - 1 ] * 2 :
455457 if self .filter_resnet_dt :
456- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn ( tf . matmul ( xyz_scatter , w ) + b ) * idt
458+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + hidden * idt
457459 else :
458- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
460+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + hidden
459461 else :
460- xyz_scatter = activation_fn ( tf . matmul ( xyz_scatter , w ) + b )
462+ xyz_scatter = hidden
461463 # natom x nei_type_i x out_size
462464 xyz_scatter = tf .reshape (xyz_scatter , (- 1 , shape_i [1 ]// 4 , outputs_size [- 1 ]))
463465 # natom x nei_type_i x 4
0 commit comments