 import time
 from collections.abc import Callable
 from functools import wraps
-from typing import ParamSpec, TypeVar
+from typing import Any, ParamSpec, TypeVar

 import numpy as np
 import pyvista as pv
 from ansys.mapdl.core import Mapdl
 from pydantic import BaseModel, Field
-from tesseract_core.runtime import Array, Differentiable, Float32, Int32
+from tesseract_core.runtime import Array, Differentiable, Float32, Int32, ShapeDType

 # Set up module logger
 logger = logging.getLogger(__name__)
@@ -110,7 +110,7 @@ class InputSchema(BaseModel):
     )

     log_level: str = Field(
-        default="INFO",
+        default="WARNING",
         description="Logging level for output messages (DEBUG, INFO, WARNING, ERROR).",
     )
@@ -522,9 +522,11 @@ def _calculate_sensitivity(self) -> None:
         inverse_rho = np.nan_to_num(1 / self.rho.flatten(), nan=0.0)
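         # SIMP compliance sensitivity (assuming rho**p-penalized element stiffness):
         # dC/drho = -p * rho**(p-1) * u.T @ k0 @ u = -2 * p * SE / rho, with SE = 0.5 * u.T @ (rho**p * k0) @ u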
         self.sensitivity = -2.0 * self.p * inverse_rho * self.strain_energy.flatten()

-        # TODO improve this cache?
-        # stash the sensitivity s.t. it may be loaded in the vjp
-        np.save("sensitivity.npy", self.sensitivity)
+        # Cache sensitivity in temporary directory for use in VJP
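+        # NOTE: fixed filename in a shared temp dir; concurrent solves would overwrite each other's cache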
+        import tempfile
+
+        sensitivity_path = os.path.join(tempfile.gettempdir(), "sensitivity.npy")
+        np.save(sensitivity_path, self.sensitivity)

     @log_timing
     def _create_pvmesh(self) -> None:
@@ -616,16 +618,55 @@ def apply(inputs: InputSchema) -> OutputSchema:
     return solver.solve()


-# TODO
-# def vector_jacobian_product(
-#     inputs: InputSchema,
-#     vjp_inputs: set[str],
-#     vjp_outputs: set[str],
-#     cotangent_vector: dict[str, Any],
-# ):
-#     pass
-#
-#
-# def abstract_eval(abstract_inputs):
-#     """Calculate output shape of apply from the shape of its inputs."""
-#     return {"compliance": ShapeDType(shape=(), dtype="float32")}
+def vector_jacobian_product(
+    inputs: InputSchema,
+    vjp_inputs: set[str],
+    vjp_outputs: set[str],
+    cotangent_vector: dict[str, Any],
+) -> dict[str, Any]:
+    """Compute the vector-Jacobian product for backpropagation through the ANSYS solve.
+
+    The sensitivity (dcompliance/drho) is already computed and cached during
+    the forward pass in _calculate_sensitivity(), so we only need to load it
+    and multiply by the upstream gradient.
+
+    Args:
+        inputs: Original inputs to the forward pass (InputSchema).
+        vjp_inputs: Set of input names to compute gradients for (e.g., {"rho"}).
+        vjp_outputs: Set of output names being differentiated
+            (e.g., {"compliance"}).
+        cotangent_vector: Upstream gradients from the loss function
+            (e.g., {"compliance": dloss/dcompliance}).
+
+    Returns:
+        Dictionary mapping input names to their gradients
+        (e.g., {"rho": dloss/drho}).
+    """
+    gradients = {}
+    assert vjp_inputs == {"rho"}
+    assert vjp_outputs == {"compliance"}
+
+    # Load the cached sensitivity (dcompliance/drho) from the temporary directory.
+    # It was computed and saved in _calculate_sensitivity().
+    import tempfile
+
+    sensitivity_path = os.path.join(tempfile.gettempdir(), "sensitivity.npy")
+    sensitivity = np.load(sensitivity_path)
+
+    # Clean up the temporary file after loading
+    try:
+        os.unlink(sensitivity_path)
+    except FileNotFoundError:
+        pass  # Already deleted, no problem
+
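+    # Chain rule: dloss/drho = (dloss/dcompliance) * (dcompliance/drho)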
+    grad_rho_flat = cotangent_vector["compliance"] * sensitivity
+
+    # Reshape to match the input rho shape
+    gradients["rho"] = grad_rho_flat.reshape(inputs.rho.shape)
+
+    return gradients
+
+
+def abstract_eval(abstract_inputs):
+    """Calculate output shape of apply from the shape of its inputs."""
+    return {"compliance": ShapeDType(shape=(), dtype="float32")}
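
For reference, a minimal sketch of how the new VJP pairs with the forward pass, with a finite-difference spot check. This is an assumed usage, not part of the diff: constructing `inputs` is elided (only the `rho` and `log_level` fields appear above), `OutputSchema` is assumed to expose a `compliance` value as `abstract_eval` suggests, `rho` is treated as a NumPy array at runtime, and `model_copy` assumes pydantic v2. Each `apply` call is a full ANSYS solve:

# Hypothetical sketch: `inputs` is an InputSchema constructed elsewhere
base = apply(inputs).compliance  # forward pass also caches sensitivity.npy

# Seed with dloss/dcompliance = 1.0 to recover dcompliance/drho directly
grads = vector_jacobian_product(
    inputs, {"rho"}, {"compliance"}, {"compliance": np.float32(1.0)}
)

# Finite-difference spot check on one (arbitrary) density entry
eps = 1e-4
rho_plus = inputs.rho.copy()
rho_plus.flat[0] += eps
perturbed = apply(inputs.model_copy(update={"rho": rho_plus})).compliance
print(grads["rho"].flat[0], (perturbed - base) / eps)  # should roughly agree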