Skip to content

Conversation

@Dekermanjian
Copy link
Contributor

This is a draft proposal for #598

The idea is to handle each component separately using _set_{component} methods and all information are stored using data classes for easy mapping.

I believe this will simplify our tests of these components and will reduce redundancies where we have the same information spread across multiple sub-components like data_names and data_info.

@jessegrabowski let me know what you think I put a little notebook together to showcase the changes.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@Dekermanjian Dekermanjian marked this pull request as draft November 2, 2025 17:50
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great first pass, much cleaner than what we have now.

@jessegrabowski
Copy link
Member

We can also keep all of the existing properties like state_names, shock_names, state_dims, etc, but move them to the base class and just extract the requested info from the relevant Info objects.

@jessegrabowski jessegrabowski changed the title proposal for updating propogate_component_properties using data classes Represent statespace metadata with dataclasses Nov 7, 2025
@jessegrabowski
Copy link
Member

Reflecting on it, I am convinced this is the way to go. It's 1000x more ergonomic. I made some changes to your initial code to make the API more "dictionary like", and to reduce code duplication. I moved everything to statespace/core/properties.py, because this is ultimately going to replace what we have in both the core models and the components.

@Dekermanjian
Copy link
Contributor Author

@jessegrabowski, this is looking really cool! What can I do to help push this forward?

@jessegrabowski
Copy link
Member

jessegrabowski commented Nov 7, 2025

Delete the new regression_dataclass.py and simply refactor regression.py to use the new stuff.

We should keep your notebook with the plan to add it as a new example for the docs. Or it can be merged into the custom statespace notebook. So that should also be updated to import from the new properties.py file

@Dekermanjian
Copy link
Contributor Author

Perfect! I'll work on that today!! It is really looking cool!

SHOCK_DIM: ss_mod.shock_names,
SHOCK_AUX_DIM: ss_mod.shock_names,
}
ALL_STATE_COORD = Coord(dimension=ALL_STATE_DIM, labels=ss_mod.state_names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should be eliminated in favor of the method CoordsInfo.defaults_from_model

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment still stands. Is there anywhere we need this instead of just using CoordsInfo.defaults_from_model?

@Dekermanjian
Copy link
Contributor Author

@jessegrabowski, I agree with all of your comments above. I am going to start making those changes.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incomplete review, I'll continue tomorrow AM

Comment on lines 47 to 50
# if key in index:
# raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That shouldn't happen here though, it should come up in merge or add right? And we handle it there with the allow_duplicates flag

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what happens is because our data classes are immutable the __post_init__ runs right after our merge/add because we always return new objects of the same dataclass and it see that there are duplicate keys even though the merge/add method had allowed them via allow_duplicates.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There shouldn't be duplicate keys in the final result though. My understanding was that if we set allow_duplicates=False, there's essentially a runtime guard that we aren't trying to add a key that already exists (this will error). If True, we don't raise an error and overwrite the existing key, like a python dictionary.

@Dekermanjian
Copy link
Contributor Author

Hey @jessegrabowski, by switching a lot of the component attributes to properties I was able to simplify a good amount of downstream methods. If you don't mind taking a look at the current state of this before I go ahead and do the same with the rest of the SSM components.


self.coords_info = CoordInfo(coords=[regression_state_coord, endogenous_state_coord])

def populate_component_properties(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method won't be unique to regression right? We will want to move it up to the base class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski, in the base class there is a populate_component_properties method that raises a NotImplemented. Did you want to replace that with a generic method that sets _set_<foo> for the 2 defaults (shocks and states) that we provide?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generic populate_component_propterties should call everything. For any setter that raises by default, we should always implement it (even if we just have def _set_foo(self): pass). Then it's really obvious that this component doesn't have a specific property. Alternatively we can set an empty Info.

For optional properties, we should make sure it's harmless to call e.g. _set_data_info() for components that don't use it.

@jessegrabowski
Copy link
Member

Yeah it looks really great! Go ahead and do the others. Excited to get this over the finish line

@Dekermanjian
Copy link
Contributor Author

Yeah it looks really great! Go ahead and do the others. Excited to get this over the finish line

Yeah, this is going to be pretty cool! I am going to commit + push the components one by one as I complete them. I will also let you know here once I get all of them done so that if you prefer reviewing everything all at once.

@Dekermanjian
Copy link
Contributor Author

Alright! @jessegrabowski, I refactored all the structural components to use the new dataclass architecture. All tests under ./tests/statespace/models/structural/components/ pass. We will need to refactor the models (ETS, S/VARIMA/X, etc) to use the new architecture.

@jessegrabowski
Copy link
Member

jessegrabowski commented Dec 29, 2025

Amazing! How is your impression of it after refactoring everything? Does it seem a bit more readable? Any sharp edges that are still confusing?

Also it looks like you need to rebase and run pre-commit

@Dekermanjian
Copy link
Contributor Author

Dekermanjian commented Dec 29, 2025

I think it is more readable. There are some repeated code that I am trying to figure out how to make less repeatable. These are the k_<foo> variables that you will see at the start of many of the _set_<foo> methods. I am sure once you take a look in your review we can iron them out.

@Dekermanjian
Copy link
Contributor Author

Dekermanjian commented Dec 29, 2025

Also it looks like you need to rebase and run pre-commit

Yes, I definitely need to rebase, but I am pretty sure the pre-commit did run. hmm that is odd.

…uplicate with warning

2. removed unnecessary imports from __init__ after deleting regression_dataclass
3. updated components and structural classes to only utilize dataclasses and pull other objects from <foo>_info dataclasses
4. updated tests to conform to dataclass api
2. created tests for add and merge methods
3. added utility to convert from snake to pascal and integrated it in error messaging
…_duplicates is False

2. converted component attributes into properties
3. removed _combine_property method
4. removed redundant observed_states property
5. fixed indentation bug
2. Added TensorVariable and TensorData properties for use with make_and_register_variable/data
3. Updated regression component to use TensorData property
@Dekermanjian Dekermanjian force-pushed the ssm_populate_component_properties branch from c80a065 to f08f7d6 Compare December 29, 2025 21:10
@Dekermanjian
Copy link
Contributor Author

@jessegrabowski, I rebased and ran pre-commit again but the pre-commit says everything passed.

Screenshot 2025-12-29 at 2 12 04 PM

@jessegrabowski
Copy link
Member

@jessegrabowski, I rebased and ran pre-commit again but the pre-commit says everything passed.

Looks like the CI is happy now :)

@@ -0,0 +1,257 @@
from __future__ import annotations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The robot always adds this -- why do we need it?

Comment on lines 47 to 50
# if key in index:
# raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There shouldn't be duplicate keys in the final result though. My understanding was that if we set allow_duplicates=False, there's essentially a runtime guard that we aren't trying to add a key that already exists (this will error). If True, we don't raise an error and overwrite the existing key, like a python dictionary.

SHOCK_DIM: ss_mod.shock_names,
SHOCK_AUX_DIM: ss_mod.shock_names,
}
ALL_STATE_COORD = Coord(dimension=ALL_STATE_DIM, labels=ss_mod.state_names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment still stands. Is there anywhere we need this instead of just using CoordsInfo.defaults_from_model?

)


@pytest.mark.filterwarnings("ignore:Duplicate names found:UserWarning")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicates shouldn't ever warn, it should raise (if disallowed) or silently accept it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the argument name allow_duplicates to overwrite_duplicates because that is really what is happening. I was able to impose the logic of raising if there is a duplicate key. I will remove the warning message of overwriting duplicate names which is required when we share states.

2. added raise value error on duplicate keys in Info class
3. updated arg name from allow_duplicates to overwrite_duplicates
4. updated component child classes so that graph construction method it as the bottom of the file
5. updated setter methods in component child classes so that they can be called in any order
6. removed warning when overwriting duplicate names in info class
7. reduced complexity by using parameter containers in if blocks
8. Switched merge to add methods for TensorVariable and TensorData construction
9. renamed dataclass TensorVariable and TensoreVariableInfo due to conflict with pt.TensorVariable
@Dekermanjian
Copy link
Contributor Author

Dekermanjian commented Dec 30, 2025

@jessegrabowski, I had to rename the data classes TensorVariable and TensorVariableInfo to PyTensorVariable and PyTensorVariableInfo because there was a conflict with Pytensor's pt.TensorVariable class.

I preferred the older names but didn't want to risk any issues.

@jessegrabowski
Copy link
Member

@jessegrabowski, I had to rename the data classes TensorVariable and TensorVariableInfo to PyTensorVariable and PyTensorVariableInfo because there was a conflict with Pytensor's pt.TensorVariable class.

What about SymbolicTensor and SymbolicTensorInfo ?

@Dekermanjian
Copy link
Contributor Author

@jessegrabowski, I had to rename the data classes TensorVariable and TensorVariableInfo to PyTensorVariable and PyTensorVariableInfo because there was a conflict with Pytensor's pt.TensorVariable class.

What about SymbolicTensor and SymbolicTensorInfo ?

That is much better! I will switch it to that.

…bolicData

2. removed a remnant of _name_to_vars being used in make_and_register_data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment