@@ -21,16 +21,32 @@ def __init__(
21
21
self ._num_nodes = num_nodes
22
22
self ._requires_grad = requires_grad
23
23
24
- # register begin and end as buffers
25
- if len (begin .shape ) == 1 or begin .shape [0 ] == 1 :
26
- self .register_buffer ("begin" , begin .detach ().view ((1 , - 1 ))) # 1xD
24
+ # if either begin or end only has one point, while the other has a batch
25
+ # then we expand the singular point. End result is that both begin and
26
+ # end should have shape BxD
27
+ batch_begin = 1 if len (begin .shape ) == 1 else begin .shape [0 ]
28
+ batch_end = 1 if len (end .shape ) == 1 else end .shape [0 ]
29
+ if batch_begin == 1 and batch_end == 1 :
30
+ _begin = begin .detach ().view ((1 , - 1 )) # 1xD
31
+ _end = end .detach ().view ((1 , - 1 )) # 1xD
32
+ elif batch_begin == 1 : # batch_end > 1
33
+ _begin = begin .detach ().view ((1 , - 1 )).repeat (batch_end , 1 ) # BxD
34
+ _end = end .detach () # BxD
35
+ elif batch_end == 1 : # batch_begin > 1
36
+ _begin = begin .detach () # BxD
37
+ _end = end .detach ().view ((1 , - 1 )).repeat (batch_begin , 1 ) # BxD
38
+ elif batch_begin == batch_end :
39
+ _begin = begin .detach () # BxD
40
+ _end = end .detach () # BxD
27
41
else :
28
- self .register_buffer ("begin" , begin .detach ()) # BxD
42
+ raise ValueError (
43
+ "BasicCurve.__init__ requires begin and end points to have "
44
+ "the same shape"
45
+ )
29
46
30
- if len (end .shape ) == 1 or end .shape [0 ] == 1 :
31
- self .register_buffer ("end" , end .detach ().view ((1 , - 1 ))) # 1xD
32
- else :
33
- self .register_buffer ("end" , end .detach ()) # BxD
47
+ # register begin and end as buffers
48
+ self .register_buffer ("begin" , _begin ) # BxD
49
+ self .register_buffer ("end" , _end ) # BxD
34
50
35
51
# overriden by child modules
36
52
self ._init_params (* args , ** kwargs )
0 commit comments