Skip to content

Commit 0d476ef

Browse files
committed
wip: new op
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 66924f7 commit 0d476ef

33 files changed

+1551
-2889
lines changed

docs/guides/operators.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import jax.numpy as jnp
2020
Galilean operators represent the basic transformations in classical mechanics:
2121
translations, rotations, and boosts.
2222

23-
### GalileanSpatialTranslation
23+
### GalileanOp
2424

2525
Translates position vectors by a fixed offset:
2626

@@ -31,31 +31,31 @@ Translates position vectors by a fixed offset:
3131

3232
```{code-block} python
3333
>>> q = cxv.CartesianPos3D.from_([1, 2, 3], "kpc")
34-
>>> op = cxo.GalileanSpatialTranslation.from_([10, 10, 10], "kpc")
34+
>>> op = cxo.GalileanOp.from_([10, 10, 10], "kpc")
3535
>>> op(q)
3636
CartesianPos3D(
3737
x=Quantity(11, unit='kpc'), y=Quantity(12, unit='kpc'), z=Quantity(13, unit='kpc')
3838
)
3939
```
4040

41-
### GalileanBoost
41+
### Galilean Boost
4242

4343
Applies a velocity boost to a velocity vector:
4444

4545
```{code-block} python
46-
>>> boost = cxo.GalileanBoost.from_([1, 1, 1], "km/s")
46+
>>> boost = cxo.Add.from_([1, 1, 1], "km/s")
4747
>>> boost(u.Quantity(1.0, "s"), q)[1]
4848
CartesianPos3D(
4949
x=Quantity(1., unit='kpc'), y=Quantity(2., unit='kpc'), z=Quantity(3., unit='kpc')
5050
)
5151
```
5252

53-
### GalileanRotation
53+
### Rotate
5454

5555
Rotates vectors in space:
5656

5757
```{code-block} python
58-
>>> rot = cxo.GalileanRotation.from_euler("z", u.Quantity(90, "deg"))
58+
>>> rot = cxo.Rotate.from_euler("z", u.Quantity(90, "deg"))
5959
>>> rot(q).round(2)
6060
CartesianPos3D(
6161
x=Quantity(-2., unit='kpc'),
@@ -72,8 +72,8 @@ Operators can be composed using the {class}`~coordinax.ops.Pipe` class or the
7272
`|` operator:
7373

7474
```{code-block} python
75-
>>> op1 = cxo.GalileanSpatialTranslation.from_([1, 0, 0], "kpc")
76-
>>> op2 = cxo.GalileanRotation.from_euler("z", u.Quantity(90, "deg"))
75+
>>> op1 = cxo.GalileanOp.from_([1, 0, 0], "kpc")
76+
>>> op2 = cxo.Rotate.from_euler("z", u.Quantity(90, "deg"))
7777
>>> pipe = cxo.Pipe([op1, op2])
7878
>>> pipe(q).round(2)
7979
CartesianPos3D(
@@ -97,17 +97,14 @@ CartesianPos3D(
9797

9898
- {class}`~coordinax.ops.Identity`: The do-nothing operator, useful for generic
9999
code.
100-
- {class}`~coordinax.ops.AbstractCompositeOperator`: Base for building custom
100+
- {class}`~coordinax.ops.Pipe`: Base for building custom
101101
operator pipelines.
102102

103103
---
104104

105105
## Utilities and Advanced Usage
106106

107-
- {class}`~coordinax.ops.simplify_op`: Simplifies composed operators when
108-
possible.
109-
- {class}`~coordinax.ops.convert_to_pipe_operators`: Utility to convert a list
110-
of operators into a {class}`~coordinax.ops.Pipe`.
107+
- {class}`~coordinax.ops.simplify`: Simplifies composed operators when possible.
111108

112109
---
113110

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ set of vector operations that work seamlessly with all `coordinax` vector types.
315315
```{code-block} python
316316
>>> import coordinax.ops as cxo
317317
318-
>>> op = cxo.GalileanSpatialTranslation.from_([10, 10, 10], "kpc")
318+
>>> op = cxo.GalileanOp.from_([10, 10, 10], "kpc")
319319
320320
>>> print(op(q))
321321
<CartesianPos3D: (x, y, z) [kpc]

src/coordinax/_coordinax_space_frames/frame_transforms.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,7 @@
1616
from .galactocentric import Galactocentric
1717
from .icrs import ICRS
1818
from coordinax._src.distances import Distance
19-
from coordinax._src.operators import (
20-
GalileanRotation,
21-
GalileanSpatialTranslation,
22-
Identity,
23-
Pipe,
24-
VelocityBoost,
25-
simplify_op,
26-
)
19+
from coordinax._src.operators import Add, Identity, Pipe, Rotate, simplify
2720

2821
ScalarAngle: TypeAlias = Shaped[u.Quantity["angle"] | u.Angle, ""]
2922
RotationMatrix: TypeAlias = Shaped[Array, "3 3"]
@@ -54,7 +47,7 @@ def frame_transform_op(
5447
5548
>>> @dispatch
5649
... def frame_transform_op(from_frame: MySpaceFrame, to_frame: ICRS, /) -> cx.ops.AbstractOperator:
57-
... return cx.ops.GalileanRotation.from_euler("z", u.Quantity(10, "deg"))
50+
... return cx.ops.Rotate.from_euler("z", u.Quantity(10, "deg"))
5851
5952
We can transform from `MySpaceFrame` to a Galacocentric frame, even though
6053
we don't have a direct transformation defined:
@@ -65,15 +58,15 @@ def frame_transform_op(
6558
>>> op = cx.frames.frame_transform_op(my_frame, gcf_frame)
6659
>>> op
6760
Pipe((
68-
GalileanRotation(rotation=f32[3,3]),
61+
Rotate(rotation=f32[3,3]),
6962
...
7063
))
7164
7265
""" # noqa: E501
7366
fromframe_to_icrs = frame_transform_op(from_frame, _icrs_frame)
7467
icrs_to_toframe = frame_transform_op(_icrs_frame, to_frame)
7568
pipe = fromframe_to_icrs | icrs_to_toframe
76-
return simplify_op(pipe)
69+
return simplify(pipe)
7770

7871

7972
# ---------------------------------------------------------------
@@ -116,22 +109,22 @@ def frame_transform_op(from_frame: Galactocentric, to_frame: Galactocentric, /)
116109
>>> frame_op2 = cxf.frame_transform_op(gcf_frame, gcf_frame2)
117110
>>> frame_op2
118111
Pipe((
119-
VelocityBoost(CartesianVel3D( ... )),
120-
GalileanRotation(rotation=f32[3,3]),
121-
GalileanSpatialTranslation(CartesianPos3D( ... )),
122-
GalileanRotation(rotation=f32[3,3]),
123-
GalileanRotation(rotation=f32[3,3]),
124-
GalileanSpatialTranslation(CartesianPos3D( ... )),
125-
GalileanRotation(rotation=f32[3,3]),
126-
VelocityBoost(CartesianVel3D( ... ))
112+
Add(CartesianVel3D( ... )),
113+
Rotate(rotation=f32[3,3]),
114+
Add(CartesianPos3D( ... )),
115+
Rotate(rotation=f32[3,3]),
116+
Rotate(rotation=f32[3,3]),
117+
Add(CartesianPos3D( ... )),
118+
Rotate(rotation=f32[3,3]),
119+
Add(CartesianVel3D( ... ))
127120
))
128121
129122
"""
130123
if from_frame == to_frame:
131124
return Pipe((Identity(),))
132125

133126
# TODO: not go through ICRS for the self-transformation
134-
return simplify_op(
127+
return simplify(
135128
frame_transform_op(from_frame, ICRS()) | frame_transform_op(ICRS(), to_frame)
136129
)
137130

@@ -193,10 +186,10 @@ def frame_transform_op(from_frame: ICRS, to_frame: Galactocentric, /) -> Pipe:
193186
>>> frame_op = cx.frames.frame_transform_op(icrs_frame, gcf_frame)
194187
>>> frame_op
195188
Pipe((
196-
GalileanRotation(rotation=f32[3,3]),
197-
GalileanSpatialTranslation(CartesianPos3D( ... )),
198-
GalileanRotation(rotation=f32[3,3]),
199-
VelocityBoost(CartesianVel3D( ... ))
189+
Rotate(rotation=f32[3,3]),
190+
Add(CartesianPos3D( ... )),
191+
Rotate(rotation=f32[3,3]),
192+
Add(CartesianVel3D( ... ))
200193
))
201194
202195
Apply the transformation:
@@ -238,25 +231,23 @@ def frame_transform_op(from_frame: ICRS, to_frame: Galactocentric, /) -> Pipe:
238231
239232
""" # noqa: E501
240233
# rotation matrix to align x(ICRS) with the vector to the Galactic center
241-
rot_lat = GalileanRotation.from_euler("y", to_frame.galcen.lat)
242-
rot_lon = GalileanRotation.from_euler("z", -to_frame.galcen.lon)
234+
rot_lat = Rotate.from_euler("y", to_frame.galcen.lat)
235+
rot_lon = Rotate.from_euler("z", -to_frame.galcen.lon)
243236
# extra roll away from the Galactic x-z plane
244-
roll = GalileanRotation.from_euler("x", to_frame.roll - to_frame.roll0)
237+
roll = Rotate.from_euler("x", to_frame.roll - to_frame.roll0)
245238
# construct transformation matrix
246239
R = (roll @ rot_lat @ rot_lon).simplify()
247240

248241
# Translation by Sun-Galactic center distance around x' and rotate about y'
249242
# to account for tilt due to Sun's height above the plane
250243
z_d = u.ustrip("", to_frame.z_sun / to_frame.galcen.distance) # [radian]
251-
H = GalileanRotation.from_euler("y", u.Quantity(jnp.asin(z_d), "rad"))
244+
H = Rotate.from_euler("y", u.Quantity(jnp.asin(z_d), "rad"))
252245

253246
# Post-rotation spatial offset to Galactic center.
254-
offset_q = GalileanSpatialTranslation(
255-
-to_frame.galcen.distance * jnp.asarray([1, 0, 0])
256-
)
247+
offset_q = Add(-to_frame.galcen.distance * jnp.asarray([1, 0, 0]))
257248

258249
# Post-rotation velocity offset
259-
offset_v = VelocityBoost(to_frame.galcen_v_sun)
250+
offset_v = Add(to_frame.galcen_v_sun)
260251

261252
# Total Operator
262253
return R | offset_q | H | offset_v

src/coordinax/_src/frames/coordinate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class Coordinate(AbstractCoordinate):
276276
... data=space,
277277
... frame=cx.frames.TransformedReferenceFrame(
278278
... cx.frames.Galactocentric(),
279-
... cx.ops.GalileanSpatialTranslation.from_([20, 0, 0], "kpc"),
279+
... cx.ops.GalileanOp.from_([20, 0, 0], "kpc"),
280280
... ),
281281
... )
282282

src/coordinax/_src/frames/example.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,7 @@
1111

1212
from .base import AbstractReferenceFrame
1313
from coordinax._src.frames import api
14-
from coordinax._src.operators import (
15-
GalileanBoost,
16-
GalileanRotation,
17-
GalileanSpatialTranslation,
18-
Identity,
19-
Pipe,
20-
)
14+
from coordinax._src.operators import Add, Identity, Pipe, Rotate
2115

2216

2317
@final
@@ -34,9 +28,9 @@ class Alice(AbstractReferenceFrame):
3428
>>> op = cxf.frame_transform_op(alice, bob)
3529
>>> print(op)
3630
Pipe((
37-
GalileanSpatialTranslation(<CartesianPos3D: (x, y, z) [km]
31+
Add(<CartesianPos3D: (x, y, z) [km]
3832
[100000 10000 0]>),
39-
GalileanBoost(<CartesianVel3D: (x, y, z) [m / s]
33+
Add(<CartesianVel3D: (x, y, z) [m / s]
4034
[2.698e+08 0.000e+00 0.000e+00]>)
4135
))
4236
@@ -62,9 +56,9 @@ class Bob(AbstractReferenceFrame):
6256
>>> op = cxf.frame_transform_op(alice, bob)
6357
>>> print(op)
6458
Pipe((
65-
GalileanSpatialTranslation(<CartesianPos3D: (x, y, z) [km]
59+
Add(<CartesianPos3D: (x, y, z) [km]
6660
[100000 10000 0]>),
67-
GalileanBoost(<CartesianVel3D: (x, y, z) [m / s]
61+
Add(<CartesianVel3D: (x, y, z) [m / s]
6862
[2.698e+08 0.000e+00 0.000e+00]>)
6963
))
7064
@@ -117,9 +111,9 @@ def frame_transform_op(from_frame: Alice, to_frame: FriendOfAlice, /) -> Pipe:
117111
>>> op = cx.frames.frame_transform_op(alice, friend)
118112
>>> print(op)
119113
Pipe((
120-
GalileanSpatialTranslation(<CartesianPos3D: (x, y, z) [m]
114+
Add(<CartesianPos3D: (x, y, z) [m]
121115
[10 0 0]>),
122-
GalileanRotation([[ 0. -0.99999994 0. ]
116+
Rotate([[ 0. -0.99999994 0. ]
123117
[ 0.99999994 0. 0. ]
124118
[ 0. 0. 0.99999994]])
125119
))
@@ -133,8 +127,8 @@ def frame_transform_op(from_frame: Alice, to_frame: FriendOfAlice, /) -> Pipe:
133127
))
134128
135129
"""
136-
shift = GalileanSpatialTranslation.from_([10, 0, 0], "m")
137-
rotation = GalileanRotation.from_euler("Z", u.Quantity(90, "deg"))
130+
shift = Add.from_([10, 0, 0], "m")
131+
rotation = Rotate.from_euler("Z", u.Quantity(90, "deg"))
138132
return shift | rotation
139133

140134

@@ -154,9 +148,9 @@ def frame_transform_op(from_frame: Alice, to_frame: Bob, /) -> Pipe:
154148
>>> op = cxf.frame_transform_op(alice, bob)
155149
>>> print(op)
156150
Pipe((
157-
GalileanSpatialTranslation(<CartesianPos3D: (x, y, z) [km]
151+
Add(<CartesianPos3D: (x, y, z) [km]
158152
[100000 10000 0]>),
159-
GalileanBoost(<CartesianVel3D: (x, y, z) [m / s]
153+
Add(<CartesianVel3D: (x, y, z) [m / s]
160154
[2.698e+08 0.000e+00 0.000e+00]>)
161155
))
162156
@@ -175,8 +169,8 @@ def frame_transform_op(from_frame: Alice, to_frame: Bob, /) -> Pipe:
175169
))
176170
177171
"""
178-
shift = GalileanSpatialTranslation.from_([100_000, 10_000, 0], "km")
179-
boost = GalileanBoost.from_([269_813_212.2, 0, 0], "m/s")
172+
shift = Add.from_([100_000, 10_000, 0], "km")
173+
boost = Add.from_([269_813_212.2, 0, 0], "m/s")
180174
return shift | boost
181175

182176

@@ -200,10 +194,10 @@ def frame_transform_op(
200194
>>> op = cx.frames.frame_transform_op(friend, alice)
201195
>>> print(op)
202196
Pipe((
203-
GalileanRotation([[ 0. 0.99999994 0. ]
197+
Rotate([[ 0. 0.99999994 0. ]
204198
[-0.99999994 0. 0. ]
205199
[ 0. 0. 0.99999994]]),
206-
GalileanSpatialTranslation(<CartesianPos3D: (x, y, z) [m]
200+
Add(<CartesianPos3D: (x, y, z) [m]
207201
[-10 0 0]>)
208202
))
209203

src/coordinax/_src/frames/register_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
__all__: tuple[str, ...] = ()
44

55

6+
from plum import dispatch
7+
68
from dataclassish import replace
79

810
from .coordinate import Coordinate
911
from coordinax._src.operators import AbstractOperator
1012

1113

12-
@AbstractOperator.__call__.dispatch # type: ignore[misc]
13-
def call(self: AbstractOperator, x: Coordinate, /) -> Coordinate:
14+
@dispatch
15+
def operate(self: AbstractOperator, obj: Coordinate, /) -> Coordinate:
1416
"""Apply the operator to a coordinate.
1517
1618
Examples
@@ -25,7 +27,7 @@ def call(self: AbstractOperator, x: Coordinate, /) -> Coordinate:
2527
frame=ICRS()
2628
)
2729
28-
>>> op = cx.ops.GalileanSpatialTranslation.from_([-1, -1, -1], "kpc")
30+
>>> op = cx.ops.GalileanOp.from_([-1, -1, -1], "kpc")
2931
3032
>>> new_coord = op(coord)
3133
>>> print(new_coord.data["length"])
@@ -34,4 +36,4 @@ def call(self: AbstractOperator, x: Coordinate, /) -> Coordinate:
3436
3537
"""
3638
# TODO: take the frame into account
37-
return replace(x, data=self(x.data))
39+
return replace(obj, data=self(obj.data))

src/coordinax/_src/frames/register_primitives.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dataclassish import replace
1010

1111
from .coordinate import Coordinate
12+
from coordinax._src.vectors.base_pos.core import AbstractPos
1213

1314

1415
@register(jax.lax.neg_p)
@@ -33,3 +34,26 @@ def neg_p_coord(x: Coordinate, /) -> Coordinate:
3334
3435
"""
3536
return replace(x, data=-x.data)
37+
38+
39+
@register(jax.lax.add_p)
40+
def add_p_coord_pos(x: Coordinate, y: AbstractPos, /) -> Coordinate:
41+
r"""Add a position vector to a coordinate.
42+
43+
To understand this operation, let's consider a phase-space point $(x, v) \in
44+
\mathbb{R}^3\times\mathbb{R}^3$ consisting of a position and a velocity. A
45+
pure spatial translation is the map $T_{\Delta x} : (x,v) \mapsto (x+\Delta
46+
x,\ v)$, i.e. only the position is shifted; velocity is unchanged.
47+
48+
"""
49+
# Get the Cartesian class for the coordinate's position
50+
cart_cls = y.cartesian_type
51+
# Convert the coordinate to that class. This changes the position, but also
52+
# the other components, e.g. the velocity.
53+
data = dict(x.data.vconvert(cart_cls))
54+
# Now add the position vector to the position component only
55+
data = replace(data, length=data["length"] + y)
56+
# Transform back to the original vector types
57+
# data.vconvert() # TODO: all original types
58+
# Reconstruct the Coordinate
59+
return Coordinate(data, frame=x.frame)

0 commit comments

Comments
 (0)