11import warnings
22
3- from gempy_engine .config import DEBUG_MODE , AvailableBackends
4- from gempy_engine .core .backend_tensor import BackendTensor as bt , BackendTensor
3+ from ...config import DEBUG_MODE , AvailableBackends
4+ from ...core .backend_tensor import BackendTensor as bt , BackendTensor
5+ from ...core .data .exported_fields import ExportedFields
6+ from ._soft_segment import soft_segment_unbounded
7+
58import numpy as np
69import numbers
710
8- from gempy_engine .core .data .exported_fields import ExportedFields
9-
1011
1112def activate_formation_block (exported_fields : ExportedFields , ids : np .ndarray ,
1213 sigmoid_slope : float ) -> np .ndarray :
@@ -23,9 +24,17 @@ def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray,
2324 sigmoid_slope = sigmoid_slope
2425 )
2526 else :
27+ sigm = soft_segment_unbounded (
28+ Z = Z_x ,
29+ edges = scalar_value_at_sp ,
30+ ids = ids ,
31+ sigmoid_slope = sigmoid_slope
32+ )
33+ return sigm
34+
2635 match BackendTensor .engine_backend :
2736 case AvailableBackends .PYTORCH :
28- sigm = soft_segment_unbounded (
37+ sigm = soft_segment_unbounded_torch (
2938 Z = Z_x ,
3039 edges = scalar_value_at_sp ,
3140 ids = ids ,
@@ -85,7 +94,7 @@ def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_
8594import torch
8695
8796
88- def soft_segment_unbounded (Z , edges , ids , sigmoid_slope ):
97+ def soft_segment_unbounded_torch (Z , edges , ids , sigmoid_slope ):
8998 """
9099 Z: (...,) tensor of scalar values
91100 edges: (K-1,) tensor of finite split points [e1, e2, ..., e_{K-1}]
@@ -124,7 +133,7 @@ def soft_segment_unbounded(Z, edges, ids, sigmoid_slope):
124133
125134 # weighted sum by the ids
126135 ids__sum = (membership * ids ).sum (dim = - 1 )
127-
136+
128137 # make it at least 2d
129138 ids__sum = ids__sum [None , :]
130139
@@ -151,10 +160,9 @@ def soft_segment_unbounded_np(Z, edges, ids, sigmoid_slope):
151160 case np .ndarray ():
152161 membership = _final_faults_segmentation (Z , edges , sigmoid_slope )
153162 case numbers .Number ():
154- membership = _lith_segmentation (Z , edges , ids , sigmoid_slope )
163+ membership = _lith_segmentation (Z , edges , ids , sigmoid_slope )
155164 case _:
156- raise ValueError ("sigmoid_slope must be a float or an array" )
157-
165+ raise ValueError ("sigmoid_slope must be a float or an array" )
158166
159167 ids__sum = np .sum (membership * ids , axis = - 1 )
160168 return np .atleast_2d (ids__sum )
@@ -179,7 +187,6 @@ def _final_faults_segmentation(Z, edges, sigmoid_slope):
179187
180188
181189def _lith_segmentation (Z , edges , ids , sigmoid_slope ):
182-
183190 # 1) per-edge temperatures τ_k = |Δ_k|/(4·m)
184191 jumps = np .abs (ids [1 :] - ids [:- 1 ]) # shape (K-1,)
185192 tau_k = jumps / float (sigmoid_slope ) # shape (K-1,)
0 commit comments