1717
1818
1919def test_activator_3_layers_segmentation_function (simple_model_3_layers , simple_grid_3d_more_points_grid ):
20- interpolation_input = simple_model_3_layers [0 ]
21- options = simple_model_3_layers [1 ]
22- data_shape = simple_model_3_layers [2 ].tensors_structure
23- grid = dataclasses .replace (simple_grid_3d_more_points_grid )
24- interpolation_input .set_temp_grid (grid )
25-
26- ids = np .array ([1 , 2 , 3 , 4 ])
20+ Z_x , grid , ids_block , interpolation_input = _run_test (
21+ backend = AvailableBackends .numpy ,
22+ ids = np .array ([1 , 20 , 3 , 4 ]),
23+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
24+ simple_model_3_layers = simple_model_3_layers
25+ )
2726
28- interp_input : SolverInput = input_preprocess ( data_shape , interpolation_input )
29- weights = _solve_interpolation ( interp_input , options . kernel_options )
27+ if plot :
28+ _plot_continious ( grid , ids_block , interpolation_input )
3029
31- exported_fields = _evaluate_sys_eq (interp_input , weights , options )
32- exported_fields .set_structure_values (
33- reference_sp_position = data_shape .reference_sp_position ,
34- slice_feature = interpolation_input .slice_feature ,
35- grid_size = interpolation_input .grid .len_all_grids )
3630
37- Z_x : np .ndarray = exported_fields .scalar_field
38- sasp = exported_fields .scalar_field_at_surface_points
39- ids = np .array ([1 , 20 , 3 , 4 ])
31+ def test_activator_3_layers_segmentation_function_II (simple_model_3_layers , simple_grid_3d_more_points_grid ):
32+ Z_x , grid , ids_block , interpolation_input = _run_test (
33+ backend = AvailableBackends .numpy ,
34+ ids = np .array ([1 , 2 , 3 , 4 ]),
35+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
36+ simple_model_3_layers = simple_model_3_layers
37+ )
4038
41- print (Z_x , Z_x .shape [0 ])
42- print (sasp )
39+ BackendTensor .change_backend_gempy (AvailableBackends .numpy )
4340
44- ids_block = activate_formation_block (
45- exported_fields = exported_fields ,
46- ids = ids ,
47- sigmoid_slope = 500 * 4
48- )[0 , :- 7 ]
41+ if plot :
42+ _plot_continious (grid , ids_block , interpolation_input )
4943
50- if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
51- ids_block = ids_block .detach ().numpy ()
52- Z_x = Z_x .detach ().numpy ()
53- interpolation_input .surface_points .sp_coords = interpolation_input .surface_points .sp_coords .detach ().numpy ()
5444
45+ def test_activator_3_layers_segmentation_function_torch (simple_model_3_layers , simple_grid_3d_more_points_grid ):
46+ Z_x , grid , ids_block , interpolation_input = _run_test (
47+ backend = AvailableBackends .PYTORCH ,
48+ ids = np .array ([1 , 2 , 3 , 4 ]),
49+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
50+ simple_model_3_layers = simple_model_3_layers
51+ )
5552 if plot :
5653 _plot_continious (grid , ids_block , interpolation_input )
5754
5855
59- def test_activator_3_layers_segmentation_function_II ( simple_model_3_layers , simple_grid_3d_more_points_grid ):
56+ def _run_test ( backend , ids , simple_grid_3d_more_points_grid , simple_model_3_layers ):
6057 interpolation_input = simple_model_3_layers [0 ]
6158 options = simple_model_3_layers [1 ]
6259 data_shape = simple_model_3_layers [2 ].tensors_structure
6360 grid = dataclasses .replace (simple_grid_3d_more_points_grid )
6461 interpolation_input .set_temp_grid (grid )
65-
6662 interp_input : SolverInput = input_preprocess (data_shape , interpolation_input )
6763 weights = _solve_interpolation (interp_input , options .kernel_options )
68-
6964 exported_fields = _evaluate_sys_eq (interp_input , weights , options )
7065 exported_fields .set_structure_values (
7166 reference_sp_position = data_shape .reference_sp_position ,
7267 slice_feature = interpolation_input .slice_feature ,
7368 grid_size = interpolation_input .grid .len_all_grids )
74-
7569 Z_x : np .ndarray = exported_fields .scalar_field
7670 sasp = exported_fields .scalar_field_at_surface_points
77- ids = np .array ([1 , 2 , 3 , 4 ])
78-
7971 print (Z_x , Z_x .shape [0 ])
8072 print (sasp )
81-
82- BackendTensor .change_backend_gempy (AvailableBackends .numpy )
73+ BackendTensor .change_backend_gempy (backend )
8374 ids_block = activate_formation_block (
8475 exported_fields = exported_fields ,
8576 ids = ids ,
8677 sigmoid_slope = 500 * 4
8778 )[0 , :- 7 ]
88-
89- BackendTensor .change_backend_gempy (AvailableBackends .numpy )
90- if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
91- ids_block = ids_block .detach ().numpy ()
92- Z_x = Z_x .detach ().numpy ()
93- interpolation_input .surface_points .sp_coords = interpolation_input .surface_points .sp_coords .detach ().numpy ()
94-
95- if plot :
96- _plot_continious (grid , ids_block , interpolation_input )
79+ return Z_x , grid , ids_block , interpolation_input
9780
9881
9982def _plot_continious (grid , ids_block , interpolation_input ):
@@ -113,5 +96,3 @@ def _plot_continious(grid, ids_block, interpolation_input):
11396 plt .plot (xyz [:, 0 ], xyz [:, 2 ], "o" )
11497 plt .colorbar ()
11598 plt .show ()
116-
117-
0 commit comments