Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: linting-and-code-formatting

on:
workflow_dispatch:
pull_request:
branches:
- main
types: [opened, synchronize, reopened]
push:

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --upgrade .[dev]
# Uncomment when internal linter bug has been fixed
# - name: pyink
# run: |
# pyink --check torax
- name: isort
run: |
isort --check-only torax
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ documentation = "https://github.com/google-deepmind/torax/blob/main/README.md"
[project.optional-dependencies]
# Installed through `pip install -e .[dev]`
dev = [
"isort",
"pytest",
"pytest-xdist",
"pytest-shard",
Expand All @@ -76,6 +77,10 @@ tutorial = [
"notebook",
]

[tool.isort]
profile = "google"
known_third_party = 'torax'

[tool.pyink]
# Formatting configuration to follow Google style-guide
line-length = 80
Expand Down
5 changes: 2 additions & 3 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import logging
import os

import jax

# pylint: disable=g-importing-member

import jax
from torax._src import version
from torax._src.config.config_loader import build_torax_config_from_file
from torax._src.config.config_loader import import_module
Expand Down Expand Up @@ -69,4 +67,5 @@ def set_jax_precision():
def log_jax_backend():
logging.info('JAX running on a default %s backend', jax.default_backend())


set_jax_precision()
2 changes: 1 addition & 1 deletion torax/_src/config/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import types
import typing
from typing import Any, Literal, TypeAlias

from torax._src.plotting import plotruns_lib
from torax._src.torax_pydantic import model_config


ExampleConfig: TypeAlias = Literal[
'basic_config', 'iterhybrid_predictor_corrector', 'iterhybrid_rampup'
]
Expand Down
1 change: 1 addition & 0 deletions torax/_src/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class StaticNumerics:

For definitions see `Numerics`.
"""

evolve_ion_heat: bool
evolve_electron_heat: bool
evolve_current: bool
Expand Down
2 changes: 2 additions & 0 deletions torax/_src/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""Profile condition parameters used throughout TORAX simulations."""
import dataclasses
from typing import Callable, Final

import chex
import jax
import numpy as np
import pydantic
from torax._src import array_typing
from torax._src.torax_pydantic import torax_pydantic
from typing_extensions import Self

# pylint: disable=invalid-name

# Order of magnitude validations to catch common config errors.
Expand Down
2 changes: 2 additions & 0 deletions torax/_src/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from torax._src.sources import runtime_params as sources_params
from torax._src.torax_pydantic import torax_pydantic
from torax._src.transport_model import runtime_params as transport_model_params

# Many of the variables follow scientific or mathematical notation, so disable
# pylint complaints.
# pylint: disable=invalid-name
Expand Down Expand Up @@ -107,6 +108,7 @@ class StaticRuntimeParamsSlice:
TODO(b/335596447): Add function to help users detect whether their
change in config will trigger a recompile.
"""

# Solver-specific static runtime params.
solver: solver_params.StaticRuntimeParams
# Mapping of source name to source-specific static runtime params.
Expand Down
1 change: 0 additions & 1 deletion torax/_src/config/runtime_validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from torax._src import constants
from torax._src.torax_pydantic import torax_pydantic


_TOLERANCE: Final[float] = 1e-6


Expand Down
1 change: 1 addition & 0 deletions torax/_src/config/tests/config_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import pathlib
import typing

from absl.testing import absltest
from absl.testing import parameterized
from torax._src.config import config_loader
Expand Down
1 change: 1 addition & 0 deletions torax/_src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import dataclasses
from typing import Final, Mapping

import chex
import immutabledict
import jax
Expand Down
1 change: 1 addition & 0 deletions torax/_src/core_profiles/convertors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
from typing import Final, Mapping, Tuple

import immutabledict
from torax._src import state
from torax._src.fvm import cell_variable
Expand Down
1 change: 1 addition & 0 deletions torax/_src/core_profiles/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Functions used for initializing core profiles."""

import dataclasses

import jax
from jax import numpy as jnp
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions torax/_src/core_profiles/tests/boundary_conditions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions torax/_src/core_profiles/tests/getters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
Expand Down
1 change: 1 addition & 0 deletions torax/_src/core_profiles/updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"""
import dataclasses
import functools

import jax
from jax import numpy as jnp
from torax._src import array_typing
Expand Down
1 change: 0 additions & 1 deletion torax/_src/fvm/block_1d_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import jax


# An optional argument, consisting of a 2D matrix of nested tuples, with each
# leaf being either None or a JAX Array. Used to define block matrices.
# examples:
Expand Down
1 change: 0 additions & 1 deletion torax/_src/fvm/discrete_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from torax._src.fvm import convection_terms
from torax._src.fvm import diffusion_terms


AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs

Expand Down
1 change: 1 addition & 0 deletions torax/_src/fvm/fvm_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Conversions utilities for fvm objects."""

import dataclasses

import jax
from jax import numpy as jnp
from torax._src import state
Expand Down
1 change: 1 addition & 0 deletions torax/_src/fvm/implicit_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import dataclasses
import functools

import jax
from jax import numpy as jnp
from torax._src import jax_utils
Expand Down
4 changes: 1 addition & 3 deletions torax/_src/fvm/residual_and_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def theta_method_block_residual(
return residual


# pylint: disable=missing-function-docstring
@functools.partial(
jax_utils.jit,
static_argnames=[
Expand All @@ -315,9 +314,8 @@ def theta_method_block_residual(
'neoclassical_models',
],
)
def theta_method_block_jacobian(*args, **kwargs):
def theta_method_block_jacobian(*args, **kwargs): # pylint: disable=missing-function-docstring
return jax.jacfwd(theta_method_block_residual)(*args, **kwargs)
# pylint: enable=missing-function-docstring


@functools.partial(
Expand Down
5 changes: 2 additions & 3 deletions torax/_src/fvm/tests/calc_coeffs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy

from absl.testing import absltest
from absl.testing import parameterized
from torax._src.config import build_runtime_params
Expand Down Expand Up @@ -110,9 +111,7 @@ def create_coeffs_callback(
transport_model = torax_config.transport.build_transport_model()
evolving_names = tuple(['T_i'])
source_models = torax_config.sources.build_models()
neoclassical_models = (
torax_config.neoclassical.build_models()
)
neoclassical_models = torax_config.neoclassical.build_models()
static_runtime_params_slice = (
build_runtime_params.build_static_params_from_config(torax_config)
)
Expand Down
6 changes: 1 addition & 5 deletions torax/_src/fvm/tests/cell_variable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,11 +574,7 @@ def test_eq_not_cell_variable(self):
),
)
def test_almost_equal(
self,
var1_kwargs,
var2_kwargs,
expected_almost_equal,
atol=1e-6
self, var1_kwargs, var2_kwargs, expected_almost_equal, atol=1e-6
):
var1 = cell_variable.CellVariable(**var1_kwargs)
var2 = cell_variable.CellVariable(**var2_kwargs)
Expand Down
1 change: 0 additions & 1 deletion torax/_src/geometry/geometry_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import scipy
import torax


# Internal import.
# Internal import.

Expand Down
2 changes: 2 additions & 0 deletions torax/_src/geometry/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import functools
import inspect
from typing import Annotated, Any, Literal, TypeAlias, TypeVar

import pydantic
from torax._src.geometry import circular_geometry
from torax._src.geometry import geometry
from torax._src.geometry import geometry_provider
from torax._src.geometry import standard_geometry
from torax._src.torax_pydantic import torax_pydantic
import typing_extensions

# Using invalid-name because we are using the same naming convention as the
# external physics implementations
# pylint: disable=invalid-name
Expand Down
4 changes: 2 additions & 2 deletions torax/_src/geometry/tests/geometry_provider_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import dataclasses

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

from absl.testing import absltest
import numpy as np
from torax._src.geometry import geometry
Expand Down
4 changes: 2 additions & 2 deletions torax/_src/geometry/tests/geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def test_update_phibdot(self):
geo0_updated, geo1_updated = geometry.update_geometries_with_Phibdot(
dt=0.1, geo_t=geo0, geo_t_plus_dt=geo1
)
np.testing.assert_allclose(geo0_updated.Phi_b_dot, 10.)
np.testing.assert_allclose(geo1_updated.Phi_b_dot, 10.)
np.testing.assert_allclose(geo0_updated.Phi_b_dot, 10.0)
np.testing.assert_allclose(geo1_updated.Phi_b_dot, 10.0)


def _pint_face_to_cell(n_rho, face):
Expand Down
2 changes: 1 addition & 1 deletion torax/_src/geometry/tests/standard_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_validate_fbt_data_invalid_shape(self, invalid_key, invalid_shape):
'deltau',
'deltal',
'kappa',
'FtPQ', # TODO(b/412965439) remove support for LY files w/o FtPVQ.
'FtPQ', # TODO(b/412965439) remove support for LY files w/o FtPVQ.
'zA',
't',
)
Expand Down
5 changes: 2 additions & 3 deletions torax/_src/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections.abc import Mapping
import enum
from typing import Final, Literal, TypeAlias

import chex
import jax
import jax.numpy as jnp
Expand All @@ -38,9 +39,7 @@ def _step_interpolation(xs: chex.Array, x: chex.Numeric) -> chex.Array:
# and return self.ys[k]. Subtracting 1 gives index k. Setting side='left'
# means that the step occurs whenever x > self.xs. Clipping is strictly
# necessary for the case where searchsorted returns index 0.
return jnp.clip(
jnp.searchsorted(xs, x, side='left') - 1, 0, xs.shape[0] - 1
)
return jnp.clip(jnp.searchsorted(xs, x, side='left') - 1, 0, xs.shape[0] - 1)


@enum.unique
Expand Down
1 change: 1 addition & 0 deletions torax/_src/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import os
from typing import Any, Callable, Literal, TypeVar

import chex
import equinox as eqx
import jax
Expand Down
Loading