88from gempy_engine .modules .kernel_constructor .kernel_constructor_interface import yield_evaluation_grad_kernel , yield_evaluation_kernel
99
1010
11- def generic_evaluator (solver_input : SolverInput , weights : np .ndarray , options : InterpolationOptions ) -> ExportedFields :
11+ def generic_evaluator (
12+ solver_input : SolverInput ,
13+ weights : np .ndarray ,
14+ options : InterpolationOptions
15+ ) -> ExportedFields :
1216 grid_size = solver_input .xyz_to_interpolate .shape [0 ]
13- matrix_size = grid_size * weights .shape [0 ]
14- scalar_field : np .ndarray = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
17+ max_op_size = options .evaluation_chunk_size
18+ num_weights = weights .shape [0 ]
19+
20+
21+ chunk_size_grid = max (1 , int (max_op_size / num_weights )) # Ensure at least 1 point per chunk
22+ n_chunks = int (np .ceil (grid_size / chunk_size_grid ))
23+
24+ # Pre‑allocate outputs
25+ scalar_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
1526 gx_field : Optional [np .ndarray ] = None
1627 gy_field : Optional [np .ndarray ] = None
1728 gz_field : Optional [np .ndarray ] = None
18- gradient = options .compute_scalar_gradient
19-
20- # * Chunking the evaluation
21- max_size = options .evaluation_chunk_size
22- n_chunks = int (np .ceil (matrix_size / max_size ))
23- chunk_size = int (np .ceil (grid_size / n_chunks ))
24- for i in range (n_chunks ): # TODO: It seems the chunking is not properly implemented
25- slice_array = slice (i * chunk_size , (i + 1 ) * chunk_size )
26- scalar_field_chunk , gx_field_chunk , gy_field_chunk , gz_field_chunk = _eval_on (
29+ if options .compute_scalar_gradient :
30+ gx_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
31+ gy_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
32+ if options .number_dimensions == 3 :
33+ gz_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
34+
35+ # Chunked evaluation over grid indices
36+ for i in range (n_chunks ):
37+
38+ start = i * chunk_size_grid
39+ end = min (grid_size , start + chunk_size_grid ) # Ensure 'end' doesn't exceed grid_size
40+ slice_array = slice (start , end )
41+
42+ # Avoid processing empty slices if start == end
43+ if start >= end :
44+ continue
45+
46+ sf_chunk , gx_chunk , gy_chunk , gz_chunk = _eval_on (
2747 solver_input = solver_input ,
2848 weights = weights ,
2949 options = options ,
3050 slice_array = slice_array
3151 )
3252
33- scalar_field [slice_array ] = scalar_field_chunk
34- if gradient is True :
35- if i == 0 :
36- gx_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
37- gy_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
38- gz_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
39-
40- gx_field [slice_array ] = gx_field_chunk
41- gy_field [slice_array ] = gy_field_chunk
42- gz_field [slice_array ] = gz_field_chunk
53+ scalar_field [slice_array ] = sf_chunk
54+ if options .compute_scalar_gradient :
55+ gx_field [slice_array ] = gx_chunk # type: ignore
56+ gy_field [slice_array ] = gy_chunk # type: ignore
57+ if gz_field is not None :
58+ gz_field [slice_array ] = gz_chunk # type: ignore
4359
4460 if n_chunks > 5 :
4561 print (f"Chunking done: { n_chunks } chunks" )
4662
4763 return ExportedFields (scalar_field , gx_field , gy_field , gz_field )
4864
4965
50- def _eval_on (solver_input , weights , options , slice_array : slice = None ):
51- eval_kernel = yield_evaluation_kernel (solver_input , options .kernel_options , slice_array = slice_array )
52- scalar_field : np .ndarray = (eval_kernel .T @ weights ).reshape (- 1 )
53- scalar_field [- 50 :]
66+ def _eval_on (
67+ solver_input : SolverInput ,
68+ weights : np .ndarray ,
69+ options : InterpolationOptions ,
70+ slice_array : slice
71+ ):
72+ eval_kernel = yield_evaluation_kernel (
73+ solver_input , options .kernel_options , slice_array = slice_array
74+ )
75+ scalar_field = (eval_kernel .T @ weights ).reshape (- 1 )
76+
5477 gx_field : Optional [np .ndarray ] = None
5578 gy_field : Optional [np .ndarray ] = None
5679 gz_field : Optional [np .ndarray ] = None
57- if options .compute_scalar_gradient is True :
58- eval_gx_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 0 , slice_array = slice_array )
59- eval_gy_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 1 , slice_array = slice_array )
60- gx_field = (eval_gx_kernel .T @ weights ).reshape (- 1 )
61- gy_field = (eval_gy_kernel .T @ weights ).reshape (- 1 )
80+
81+ if options .compute_scalar_gradient :
82+ eval_gx = yield_evaluation_grad_kernel (
83+ solver_input , options .kernel_options , axis = 0 , slice_array = slice_array
84+ )
85+ eval_gy = yield_evaluation_grad_kernel (
86+ solver_input , options .kernel_options , axis = 1 , slice_array = slice_array
87+ )
88+ gx_field = (eval_gx .T @ weights ).reshape (- 1 )
89+ gy_field = (eval_gy .T @ weights ).reshape (- 1 )
6290
6391 if options .number_dimensions == 3 :
64- eval_gz_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 2 , slice_array = slice_array )
65- gz_field = (eval_gz_kernel .T @ weights ).reshape (- 1 )
66- elif options .number_dimensions == 2 :
67- gz_field = None
68- else :
69- raise ValueError ("Number of dimensions have to be 2 or 3" )
70- return scalar_field , gx_field , gy_field , gz_field
92+ eval_gz = yield_evaluation_grad_kernel (
93+ solver_input , options .kernel_options , axis = 2 , slice_array = slice_array
94+ )
95+ gz_field = (eval_gz .T @ weights ).reshape (- 1 )
96+ elif options .number_dimensions != 2 :
97+ raise ValueError ("`number_dimensions` must be 2 or 3" )
98+
99+ return scalar_field , gx_field , gy_field , gz_field
0 commit comments