2424
2525
2626def _get_nunconf (geo_model ) -> int :
27- return np .count_nonzero ( geo_model ._stack .df .BottomRelation == "Erosion" ) - 2 # TODO -2 n other lith series
27+ return np .count_nonzero (geo_model ._stack .df .BottomRelation == "Erosion" ) - 2 # TODO -2 n other lith series
2828
2929
3030def _get_nfaults (geo_model ) -> int :
3131 return np .count_nonzero (geo_model ._faults .df .isFault )
3232
3333
3434def _get_fault_blocks (geo_model : gp .data .GeoModel ) -> np .ndarray :
35- # n_unconf = _get_nunconf(geo_model)
36- # n_faults = _get_nfaults(geo_model)
37-
3835 fault_blocks = geo_model .solutions .raw_arrays .block_matrix [geo_model .structural_frame .group_is_fault ]
3936 resolution = geo_model .solutions .octrees_output [- 1 ].grid_centers .regular_grid .resolution
4037
@@ -43,7 +40,6 @@ def _get_fault_blocks(geo_model: gp.data.GeoModel) -> np.ndarray:
4340
4441
4542def _get_lith_blocks (geo_model : gp .data .GeoModel ) -> np .ndarray :
46-
4743 lith_blocks = geo_model .solutions .raw_arrays .block_matrix [[not x for x in geo_model .structural_frame .group_is_fault ]]
4844 resolution = geo_model .solutions .octrees_output [- 1 ].grid_centers .regular_grid .resolution
4945
@@ -158,10 +154,8 @@ def _analyze_topology(
158154 fault_shift = fault_matrix_sum .min ()
159155 fault_matrix_sum_shift = fault_matrix_sum - fault_shift
160156
161- where = np .tile (lith_matrix , (n_lith , 1 )) == np .unique (lith_matrix ).reshape (
162- - 1 , 1 )
163- lith_matrix_shift = np .sum (where * np .arange (n_lith ).reshape (- 1 , 1 ),
164- axis = 0 ) + 1
157+ where = np .tile (lith_matrix , (n_lith , 1 )) == np .unique (lith_matrix ).reshape (- 1 , 1 )
158+ lith_matrix_shift = np .sum (where * np .arange (n_lith ).reshape (- 1 , 1 ), axis = 0 ) + 1
165159
166160 topo_matrix = lith_matrix_shift + n_lith * fault_matrix_sum_shift
167161 topo_matrix_3D = topo_matrix .reshape (* res )
@@ -193,19 +187,14 @@ def _analyze_topology(
193187 else :
194188 z_edges = np .array ([[], []])
195189
196- edges = np .unique (
197- np .concatenate ((x_edges .T , y_edges .T , z_edges .T ), axis = 0 ), axis = 0
198- )
190+ edges = np .unique (np .concatenate ((x_edges .T , y_edges .T , z_edges .T ), axis = 0 ), axis = 0 )
199191
200192 centroids = _get_centroids (topo_matrix_3D )
201193
202194 return edges , centroids
203195
204196
205- def get_lot_node_to_lith_id (
206- geo_model ,
207- centroids : Dict [int , np .ndarray ]
208- ) -> Dict [int , int ]:
197+ def get_lot_node_to_lith_id (geo_model , centroids : Dict [int , np .ndarray ]) -> Dict [int , int ]:
209198 """Get look-up table to translate topology node id's back into GemPy lith
210199 id's.
211200
@@ -216,9 +205,8 @@ def get_lot_node_to_lith_id(
216205 Returns:
217206 Dict[int, int]: Look-up table translating node id -> lith id.
218207 """
219- lb = geo_model .solutions .lith_block .reshape (
220- geo_model ._grid .regular_grid .resolution
221- ).astype (int )
208+ resolution = geo_model .solutions .octrees_output [- 1 ].grid_centers .regular_grid .resolution
209+ lb = geo_model .solutions .raw_arrays .lith_block .reshape (resolution ).astype (int )
222210
223211 lot = {}
224212 for node , pos in centroids .items ():
@@ -228,9 +216,7 @@ def get_lot_node_to_lith_id(
228216 return lot
229217
230218
231- def get_lot_lith_to_node_id (
232- lot : Dict [int , np .ndarray ]
233- ) -> Dict [int , List [int ]]:
219+ def get_lot_lith_to_node_id (lot : Dict [int , np .ndarray ]) -> Dict [int , List [int ]]:
234220 """Get look-up table to translate lith id's back into topology node
235221 id's.
236222
@@ -250,10 +236,7 @@ def get_lot_lith_to_node_id(
250236 return lot2
251237
252238
253- def get_lot_node_to_fault_block (
254- geo_model ,
255- centroids : Dict [int , np .ndarray ]
256- ) -> Dict [int , int ]:
239+ def get_lot_node_to_fault_block ( geo_model , centroids : Dict [int , np .ndarray ] ) -> Dict [int , int ]:
257240 """Get a look-up table to access fault block id's for each topology node
258241 id.
259242
@@ -280,16 +263,14 @@ def get_fault_ids(geo_model) -> List[int]:
280263 Returns:
281264 List[int]: List of fault id's.
282265 """
283- f_series_names = geo_model ._faults .df [geo_model ._faults .df .isFault ].index
284- fault_ids = [0 ]
285- for fsn in f_series_names :
286- fid = geo_model ._surfaces .df [
287- geo_model ._surfaces .df .series == fsn ].id .values [0 ]
288- fault_ids .append (fid )
266+ group_is_fault : list [bool ] = geo_model .structural_frame .group_is_fault
267+ n_faults = np .sum (group_is_fault )
268+ fault_ids = [i for i in range (n_faults + 1 )]
269+
289270 return fault_ids
290271
291272
292- def get_lith_ids (geo_model , basement : bool = True ) -> List [int ]:
273+ def get_lith_ids (geo_model : gp . data . GeoModel ) -> List [int ]:
293274 """ Get lithology id's of all lithologies (except basement) in given
294275 geomodel.
295276
@@ -299,16 +280,13 @@ def get_lith_ids(geo_model, basement: bool = True) -> List[int]:
299280 Returns:
300281 List[int]: List of lithology id's.
301282 """
302- fmt_series_names = geo_model ._faults .df [~ geo_model ._faults .df .isFault ].index
303- lith_ids = []
304- for fsn in fmt_series_names :
305- if not basement :
306- if fsn == "Basement" :
307- continue
308- lids = geo_model ._surfaces .df [
309- geo_model ._surfaces .df .series == fsn ].id .values
310- for lid in lids :
311- lith_ids .append (lid )
283+ # ! This is only working assuming that the faults are on top
284+ group_is_fault : list [bool ] = geo_model .structural_frame .group_is_fault
285+ n_elements = geo_model .structural_frame .n_elements
286+ n_faults = np .sum (group_is_fault )
287+
288+ lith_ids = [i for i in range (n_faults + 1 , n_elements + 1 )]
289+
312290 return lith_ids
313291
314292
@@ -349,10 +327,7 @@ def get_detailed_labels(
349327 return edges_ , centroids_
350328
351329
352- def _get_edges (
353- l : np .ndarray ,
354- r : np .ndarray
355- ) -> Optional [np .ndarray ]:
330+ def _get_edges ( l : np .ndarray , r : np .ndarray ) -> Optional [np .ndarray ]:
356331 """Get edges from given shifted arrays.
357332
358333 Args:
@@ -515,8 +490,7 @@ def plot_adjacency_matrix(
515490 n_faults = len (f_ids ) // 2
516491 lith_ids = get_lith_ids (geo_model )
517492 n_liths = len (lith_ids )
518- adj_matrix_labels , adj_matrix_lith_labels , adj_matrix_fault_labels = _get_adj_matrix_labels (
519- geo_model )
493+ adj_matrix_labels , adj_matrix_lith_labels , adj_matrix_fault_labels = _get_adj_matrix_labels (geo_model )
520494 # ///////////////////////////////////////////////////////
521495 n = len (adj_matrix_labels )
522496 fig , ax = plt .subplots (figsize = (n // 2.5 , n // 2.5 ))
@@ -536,13 +510,10 @@ def plot_adjacency_matrix(
536510
537511 # ///////////////////////////////////////////////////////
538512 # lith tick labels colors
539- colors = list (geo_model ._surfaces .colors .colordict .values ())
540- bboxkwargs = dict (
541- edgecolor = 'none' ,
542- )
543- for xticklabel , yticklabel , l in zip (ax .xaxis .get_ticklabels (),
544- ax .yaxis .get_ticklabels (),
545- adj_matrix_labels [::1 ]):
513+ colors = geo_model .structural_frame .elements_colors
514+ # colors = list(geo_model._surfaces.colors.colordict.values())
515+ bboxkwargs = dict (edgecolor = 'none' , )
516+ for xticklabel , yticklabel , l in zip (ax .xaxis .get_ticklabels (), ax .yaxis .get_ticklabels (), adj_matrix_labels [::1 ]):
546517 color = colors [l [0 ] - 1 ]
547518
548519 xticklabel .set_bbox (
@@ -569,8 +540,7 @@ def plot_adjacency_matrix(
569540 newax .spines ['left' ].set_position (('outward' , 25 ))
570541 newax .set_ylim (0 , n_faults * 2 )
571542 newax .set_yticks (np .arange (1 , n_faults * 2 + 1 ) - 0.5 )
572- newax .set_yticklabels (
573- ["FB " + str (i + 1 ) for i in range (n_faults * 2 )][::1 ])
543+ newax .set_yticklabels ( ["FB " + str (i + 1 ) for i in range (n_faults * 2 )][::1 ])
574544
575545 # ///////////////////////////////////////////////////////
576546 # (dotted) lines for fb's
@@ -601,11 +571,7 @@ def plot_adjacency_matrix(
601571 return
602572
603573
604- def check_adjacency (
605- edges : set ,
606- n1 : Union [int , str ],
607- n2 : Union [int , str ]
608- ) -> bool :
574+ def check_adjacency ( edges : set , n1 : Union [int , str ], n2 : Union [int , str ] ) -> bool :
609575 """Check if given nodes n1 and n2 are adjacent in given topology
610576 edge set.
611577
@@ -623,10 +589,7 @@ def check_adjacency(
623589 return False
624590
625591
626- def get_adjacencies (
627- edges : set ,
628- node : Union [int , str ]
629- ) -> set :
592+ def get_adjacencies ( edges : set , node : Union [int , str ] ) -> set :
630593 """Get node labels of all adjacent geobodies of geobody with given node
631594 in given set of edges.
632595
0 commit comments