Skip to content

Commit ae1354d

Browse files
committed
Add TGLFNNukaeaTransportModel
1 parent ad13527 commit ae1354d

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from fusion_surrogates.tglfnn_ukaea import config as tglfnn_ukaea_config
2+
from fusion_surrogates.tglfnn_ukaea import tglfnn_ukaea_model
3+
import jax
4+
import jax.numpy as jnp
5+
from torax._src import state
6+
from torax._src.config import runtime_params_slice
7+
from torax._src.geometry import geometry
8+
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
9+
from torax._src.transport_model import tglf_based_transport_model
10+
from torax._src.transport_model import transport_model as transport_model_lib
11+
12+
13+
class TGLFNNukaeaTransportModel(
14+
tglf_based_transport_model.TGLFBasedTransportModel
15+
):
16+
17+
def __init__(
18+
self,
19+
config_path: str,
20+
stats_path: str,
21+
efe_gb_pt: str,
22+
efi_gb_pt: str,
23+
pfi_gb_pt: str,
24+
):
25+
self._config_path = config_path
26+
self._stats_path = stats_path
27+
self._efe_gb_pt = efe_gb_pt
28+
self._efi_gb_pt = efi_gb_pt
29+
self._pfi_gb_pt = pfi_gb_pt
30+
31+
self.model = tglfnn_ukaea_model.TGLFNNukaeaModel(
32+
config=tglfnn_ukaea_config.TGLFNNukaeaModelConfig.load(config_path),
33+
stats=tglfnn_ukaea_config.TGLFNNukaeaModelStats.load(stats_path),
34+
)
35+
self.model.load_params(
36+
efe_gb_pt=efe_gb_pt, efi_gb_pt=efi_gb_pt, pfi_gb_pt=pfi_gb_pt
37+
)
38+
super().__init__()
39+
self._frozen = True
40+
41+
def _make_input_tensor(
42+
self,
43+
transport,
44+
geo,
45+
core_profiles,
46+
) -> (tglf_based_transport_model.TGLFInputs, jax.Array):
47+
tglf_inputs = self._prepare_tglf_inputs(transport, geo, core_profiles)
48+
49+
# Note: TGLFNN-ukaea uses a different definition of the magnetic shear
50+
# to TGLF. This is not the same as s_hat in s-alpha geometry.
51+
s_hat = (tglf_inputs.r_minor / tglf_inputs.q) ** 2 * tglf_inputs.q_prime
52+
tglfnn_inputs = jnp.stack(
53+
[
54+
tglf_inputs.RLNS_1,
55+
tglf_inputs.RLTS_1,
56+
tglf_inputs.RLTS_2,
57+
tglf_inputs.TAUS_2,
58+
tglf_inputs.RMIN_LOC,
59+
tglf_inputs.DRMAJDX_LOC,
60+
tglf_inputs.Q_LOC,
61+
s_hat,
62+
tglf_inputs.XNUE,
63+
tglf_inputs.KAPPA_LOC,
64+
tglf_inputs.S_KAPPA_LOC,
65+
tglf_inputs.DELTA_LOC,
66+
tglf_inputs.S_DELTA_LOC,
67+
tglf_inputs.BETAE,
68+
tglf_inputs.ZEFF,
69+
],
70+
axis=-1,
71+
)
72+
return tglf_inputs, tglfnn_inputs
73+
74+
def _call_implementation(
75+
self,
76+
transport_dynamic_runtime_params: tglf_based_transport_model.DynamicRuntimeParams,
77+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
78+
geo: geometry.Geometry,
79+
core_profiles: state.CoreProfiles,
80+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
81+
) -> transport_model_lib.TurbulentTransport:
82+
tglf_inputs, tglfnn_inputs = self._make_input_tensor(
83+
transport=transport_dynamic_runtime_params,
84+
geo=geo,
85+
core_profiles=core_profiles,
86+
)
87+
predictions = self.model.predict(tglfnn_inputs)
88+
89+
# TODO: expose variance output
90+
return self._make_core_transport(
91+
qi=predictions["efi_gb"][..., tglfnn_ukaea_config.MEAN_OUTPUT],
92+
qe=predictions["efe_gb"][..., tglfnn_ukaea_config.MEAN_OUTPUT],
93+
# TODO: TGLFNN outputs pfi, TORAX wants pfe
94+
pfe=predictions["pfi_gb"][..., tglfnn_ukaea_config.MEAN_OUTPUT],
95+
quasilinear_inputs=tglf_inputs,
96+
transport=transport_dynamic_runtime_params,
97+
geo=geo,
98+
core_profiles=core_profiles,
99+
# TODO: explain choices here
100+
gradient_reference_length=1,
101+
gyrobohm_flux_reference_length=1,
102+
)
103+
104+
def __hash__(self) -> int:
105+
combined_path = (
106+
self._config_path
107+
+ self._stats_path
108+
+ self._efe_gb_pt
109+
+ self._efi_gb_pt
110+
+ self._pfi_gb_pt
111+
)
112+
return hash(combined_path)
113+
114+
def __eq__(self, other) -> bool:
115+
return (
116+
self._config_path == other._config_path
117+
and self._stats_path == other._stats_path
118+
and self._efe_gb_pt == other._efe_gb_pt
119+
and self._efi_gb_pt == other._efi_gb_pt
120+
and self._pfi_gb_pt == other._pfi_gb_pt
121+
)

0 commit comments

Comments
 (0)