1
1
"""This module contains the IPCWEstimator class, for estimating the time to a particular event"""
2
2
3
3
import logging
4
- from numpy import ceil
5
4
from typing import Any
6
- from tqdm import tqdm
7
5
from uuid import uuid4
8
6
7
+
9
8
import numpy as np
10
9
import pandas as pd
11
10
import statsmodels .formula .api as smf
15
14
16
15
logger = logging .getLogger (__name__ )
17
16
18
- debug_id = "data-50/batch_run_16/00221634_10.csv"
19
-
20
17
21
18
class IPCWEstimator (Estimator ):
22
19
"""
@@ -152,7 +149,7 @@ def setup_fault_t_do(self, individual: pd.DataFrame):
152
149
153
150
if not fault .empty :
154
151
time_fault_observed = (
155
- max (0 , ceil (fault ["time" ].min () / self .timesteps_per_observation ) - 1 )
152
+ max (0 , np . ceil (fault ["time" ].min () / self .timesteps_per_observation ) - 1 )
156
153
) * self .timesteps_per_observation
157
154
individual .loc [individual ["time" ] == time_fault_observed , "fault_t_do" ] = 1
158
155
@@ -195,7 +192,7 @@ def preprocess_data(self):
195
192
196
193
assert (
197
194
self .df .groupby ("id" , sort = False ).apply (lambda x : len (set (x ["fault_time" ])) == 1 ).all ()
198
- ), f "Each individual must have a unique fault time."
195
+ ), "Each individual must have a unique fault time."
199
196
200
197
fault_t_do_df = self .df .groupby ("id" , sort = False )[["id" , "time" , self .status_column ]].apply (
201
198
self .setup_fault_t_do
@@ -263,7 +260,8 @@ def preprocess_data(self):
263
260
(
264
261
(
265
262
individuals ["time" ]
266
- < ceil (individuals ["fault_time" ] / self .timesteps_per_observation ) * self .timesteps_per_observation
263
+ < np .ceil (individuals ["fault_time" ] / self .timesteps_per_observation )
264
+ * self .timesteps_per_observation
267
265
)
268
266
& (~ individuals ["xo_t_do" ].isnull ())
269
267
)
@@ -275,7 +273,7 @@ def preprocess_data(self):
275
273
raise ValueError ("No individuals followed either strategy." )
276
274
self .df = individuals .loc [
277
275
individuals ["time" ]
278
- < ceil (individuals ["fault_time" ] / self .timesteps_per_observation ) * self .timesteps_per_observation
276
+ < np . ceil (individuals ["fault_time" ] / self .timesteps_per_observation ) * self .timesteps_per_observation
279
277
].reset_index ()
280
278
logger .debug (len (individuals .groupby ("id" )), "individuals" )
281
279
@@ -341,10 +339,7 @@ def estimate_hazard_ratio(self):
341
339
axis = 1 ,
342
340
).min (axis = 1 )
343
341
344
- assert (preprocessed_data ["tin" ] <= preprocessed_data ["tout" ]).all (), (
345
- f"Left before joining\n "
346
- f"{ preprocessed_data .loc [preprocessed_data ['tin' ] >= preprocessed_data ['tout' ], ['id' , 'time' , 'fault_time' , 'tin' , 'tout' ]]} "
347
- )
342
+ assert (preprocessed_data ["tin" ] <= preprocessed_data ["tout" ]).all (), f"Individuals left before joining."
348
343
349
344
preprocessed_data .to_csv ("/home/michael/tmp/preprocessed_data.csv" )
350
345
0 commit comments