Skip to content

Commit d4e64c4

Browse files
authored
Merge pull request #153 from dynamicslab/weak_optimization
Enhanced subdomain integration for the weak form library
2 parents a63ffcf + 0748ec5 commit d4e64c4

23 files changed

+1172
-697
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
hooks:
1212
- id: reorder-python-imports
1313
- repo: https://github.com/ambv/black
14-
rev: 21.5b1
14+
rev: 22.3.0
1515
hooks:
1616
- id: black
1717
- repo: https://gitlab.com/pycqa/flake8

examples/12_weakform_SINDy_examples.ipynb

Lines changed: 296 additions & 145 deletions
Large diffs are not rendered by default.

examples/1_feature_overview.ipynb

Lines changed: 27 additions & 19 deletions
Large diffs are not rendered by default.

pysindy/feature_library/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def __init__(
327327
)
328328
self.libraries_ = libraries
329329
self.inputs_per_library_ = inputs_per_library
330+
for lib in self.libraries_:
331+
if hasattr(lib, "spatiotemporal_grid"):
332+
if lib.spatiotemporal_grid is not None:
333+
self.n_samples = lib.K
334+
self.spatiotemporal_grid = lib.spatiotemporal_grid
330335

331336
def _combinations(self, lib_i, lib_j):
332337
"""
@@ -422,9 +427,12 @@ def transform(self, x):
422427
generated from applying the custom functions to the inputs.
423428
424429
"""
430+
n_samples = x.shape[0]
425431
for lib in self.libraries_:
426432
check_is_fitted(lib)
427-
n_samples = x.shape[0]
433+
if hasattr(lib, "spatiotemporal_grid"):
434+
if lib.spatiotemporal_grid is not None: # check if weak form
435+
n_samples = self.n_samples
428436

429437
# preallocate matrix
430438
xp = np.zeros((n_samples, self.n_output_features_))

pysindy/feature_library/generalized_library.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def transform(self, x):
252252

253253
n_samples, n_features = x.shape
254254

255+
if isinstance(self.libraries_[0], WeakPDELibrary):
256+
n_samples = self.libraries_[0].K * self.libraries_[0].num_trajectories
257+
255258
if float(__version__[:3]) >= 1.0:
256259
n_input_features = self.n_features_in_
257260
else:

0 commit comments

Comments
 (0)