Skip to content

Commit dfc4732

Browse files
committed
improved test, made notes for future ref
1 parent 9b6e872 commit dfc4732

File tree

2 files changed

+212
-141
lines changed

2 files changed

+212
-141
lines changed

examples/ansys/pymapdl_tess/tesseract_api.py

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
import time
1414
from collections.abc import Callable
1515
from functools import wraps
16-
from typing import ParamSpec, TypeVar
16+
from typing import Any, ParamSpec, TypeVar
1717

1818
import numpy as np
1919
import pyvista as pv
2020
from ansys.mapdl.core import Mapdl
2121
from pydantic import BaseModel, Field
22-
from tesseract_core.runtime import Array, Differentiable, Float32, Int32
22+
from tesseract_core.runtime import Array, Differentiable, Float32, Int32, ShapeDType
2323

2424
# Set up module logger
2525
logger = logging.getLogger(__name__)
@@ -110,7 +110,7 @@ class InputSchema(BaseModel):
110110
)
111111

112112
log_level: str = Field(
113-
default="INFO",
113+
default="WARNING",
114114
description="Logging level for output messages (DEBUG, INFO, WARNING, ERROR).",
115115
)
116116

@@ -522,9 +522,11 @@ def _calculate_sensitivity(self) -> None:
522522
inverse_rho = np.nan_to_num(1 / self.rho.flatten(), nan=0.0)
523523
self.sensitivity = -2.0 * self.p * inverse_rho * self.strain_energy.flatten()
524524

525-
# TODO improve this cache?
526-
# stash the sensitivity s.t. it may be loaded in the vjp
527-
np.save("sensitivity.npy", self.sensitivity)
525+
# Cache sensitivity in temporary directory for use in VJP
526+
import tempfile
527+
528+
sensitivity_path = os.path.join(tempfile.gettempdir(), "sensitivity.npy")
529+
np.save(sensitivity_path, self.sensitivity)
528530

529531
@log_timing
530532
def _create_pvmesh(self) -> None:
@@ -616,16 +618,55 @@ def apply(inputs: InputSchema) -> OutputSchema:
616618
return solver.solve()
617619

618620

619-
# TODO
620-
# def vector_jacobian_product(
621-
# inputs: InputSchema,
622-
# vjp_inputs: set[str],
623-
# vjp_outputs: set[str],
624-
# cotangent_vector: dict[str, Any],
625-
# ):
626-
# pass
627-
#
628-
#
629-
# def abstract_eval(abstract_inputs):
630-
# """Calculate output shape of apply from the shape of its inputs."""
631-
# return {"compliance": ShapeDType(shape=(), dtype="float32")}
621+
def vector_jacobian_product(
622+
inputs: InputSchema,
623+
vjp_inputs: set[str],
624+
vjp_outputs: set[str],
625+
cotangent_vector: dict[str, Any],
626+
) -> dict[str, Any]:
627+
"""Compute vector-Jacobian product for backpropagation through ANSYS solve.
628+
629+
The sensitivity (dcompliance/drho) is already computed and cached during
630+
the forward pass in _calculate_sensitivity(), so we just need to load it
631+
and multiply by the upstream gradient.
632+
633+
Args:
634+
inputs: Original inputs to the forward pass (InputSchema)
635+
vjp_inputs: Set of input names we need gradients for (e.g., {"rho"})
636+
vjp_outputs: Set of output names being differentiated (e.g., {"compliance"})
637+
cotangent_vector: Upstream gradients from the loss function
638+
(e.g., {"compliance": dloss/dcompliance})
639+
640+
Returns:
641+
Dictionary mapping input names to their gradients
642+
(e.g., {"rho": dloss/∂rho})
643+
644+
"""
645+
gradients = {}
646+
assert vjp_inputs == {"rho"}
647+
assert vjp_outputs == {"compliance"}
648+
649+
# Load the cached sensitivity (∂compliance/∂rho) from temporary directory
650+
# This was computed and saved in _calculate_sensitivity()
651+
import tempfile
652+
653+
sensitivity_path = os.path.join(tempfile.gettempdir(), "sensitivity.npy")
654+
sensitivity = np.load(sensitivity_path)
655+
656+
# Clean up the temporary file after loading
657+
try:
658+
os.unlink(sensitivity_path)
659+
except FileNotFoundError:
660+
pass # Already deleted, no problem
661+
662+
grad_rho_flat = cotangent_vector["compliance"] * sensitivity
663+
664+
# Reshape to match input rho shape
665+
gradients["rho"] = grad_rho_flat.reshape(inputs.rho.shape)
666+
667+
return gradients
668+
669+
670+
def abstract_eval(abstract_inputs):
671+
"""Calculate output shape of apply from the shape of its inputs."""
672+
return {"compliance": ShapeDType(shape=(), dtype="float32")}

0 commit comments

Comments
 (0)