@@ -96,7 +96,74 @@ def _process_ihdp_sim_data():
9696 return T , X
9797
9898
99- class StandardDGP ():
99+ class StandardDGP :
100+ """
101+ A class to generate synthetic causal datasets
102+
103+ Parameters
104+ ----------
105+ n: int
106+ Number of observations to generate
107+
108+ d_t: int
109+ Dimensionality of treatment
110+
111+ d_y: int
112+ Dimensionality of outcome
113+
114+ d_x: int
115+ Dimensionality of features
116+
117+ d_z: int
118+ Dimensionality of instrument
119+
120+ discrete_treatment: bool
121+ Whether the treatment is discrete
122+
123+ discrete_instrument: bool
124+ Whether the instrument is discrete
125+
126+ squeeze_T: bool
127+ Whether to squeeze the final T array on output
128+
129+ squeeze_Y: bool
130+ Whether to squeeze the final Y array on output
131+
132+ nuisance_Y: func or dict
133+ Nuisance function. Describes how the covariates affect the outcome.
134+ If a function, this function will be used on features X to partially generate Y.
135+ If a dict, must include 'support' and 'degree' keys.
136+
137+ nuisance_T: func or dict
138+ Nuisance function. Describes how the covariates affect the treatment.
139+ If a function, this function will be used on features X to partially generate T.
140+ If a dict, must include 'support' and 'degree' keys.
141+
142+ nuisance_TZ: func or dict
143+ Nuisance function. Describes how the instrument affects the treatment.
144+ If a function, this function will be used on instrument Z to partially generate T.
145+ If a dict, must include 'support' and 'degree' keys.
146+
147+ theta: func or dict
148+ Describes how the features affect the treatment effect heterogeneity.
149+ If a function, this function will be used on features X to calculate treatment effect heterogeneity.
150+ If a dict, must include 'support' and 'degree' keys.
151+
152+ y_of_t: func or dict
153+ Describes how the treatment affects the outcome.
154+ If a function, this function will be used directly.
155+ If a dict, must include 'support' and 'degree' keys.
156+
157+ x_eps: float
158+ Noise parameter for feature generation
159+
160+ y_eps: float
161+ Noise parameter for outcome generation
162+
163+ t_eps: float
164+ Noise parameter for treatment generation
165+
166+ """
100167 def __init__ (self ,
101168 n = 1000 ,
102169 d_t = 1 ,
@@ -114,7 +181,8 @@ def __init__(self,
114181 y_of_t = None ,
115182 x_eps = 1 ,
116183 y_eps = 1 ,
117- t_eps = 1
184+ t_eps = 1 ,
185+ random_state = None
118186 ):
119187 self .n = n
120188 self .d_t = d_t
@@ -132,15 +200,15 @@ def __init__(self,
132200 else : # else must be dict
133201 if nuisance_Y is None :
134202 nuisance_Y = {'support' : self .d_x , 'degree' : 1 }
135- nuisance_Y [ 'k' ] = self . d_x
203+ assert isinstance ( nuisance_Y , dict ), f"nuisance_Y must be a callable or dict, but got { type ( nuisance_Y ) } "
136204 self .nuisance_Y , self .nuisance_Y_coefs = self .gen_nuisance (** nuisance_Y )
137205
138206 if callable (nuisance_T ):
139207 self .nuisance_T = nuisance_T
140208 else : # else must be dict
141209 if nuisance_T is None :
142210 nuisance_T = {'support' : self .d_x , 'degree' : 1 }
143- nuisance_T [ 'k' ] = self . d_x
211+ assert isinstance ( nuisance_T , dict ), f"nuisance_T must be a callable or dict, but got { type ( nuisance_T ) } "
144212 self .nuisance_T , self .nuisance_T_coefs = self .gen_nuisance (** nuisance_T )
145213
146214 if self .d_z :
@@ -149,7 +217,9 @@ def __init__(self,
149217 else : # else must be dict
150218 if nuisance_TZ is None :
151219 nuisance_TZ = {'support' : self .d_z , 'degree' : 1 }
152- nuisance_TZ ['k' ] = self .d_z
220+ assert isinstance (
221+ nuisance_TZ , dict ), f"nuisance_TZ must be a callable or dict, but got { type (nuisance_TZ )} "
222+ nuisance_TZ = {** nuisance_TZ , 'k' : self .d_z }
153223 self .nuisance_TZ , self .nuisance_TZ_coefs = self .gen_nuisance (** nuisance_TZ )
154224 else :
155225 self .nuisance_TZ = lambda x : 0
@@ -159,14 +229,15 @@ def __init__(self,
159229 else : # else must be dict
160230 if theta is None :
161231 theta = {'support' : self .d_x , 'degree' : 1 , 'bounds' : [1 , 2 ], 'intercept' : True }
162- theta [ 'k' ] = self . d_x
232+ assert isinstance ( theta , dict ), f"theta must be a callable or dict, but got { type ( theta ) } "
163233 self .theta , self .theta_coefs = self .gen_nuisance (** theta )
164234
165235 if callable (y_of_t ):
166236 self .y_of_t = y_of_t
167237 else : # else must be dict
168238 if y_of_t is None :
169239 y_of_t = {'support' : self .d_t , 'degree' : 1 , 'bounds' : [1 , 1 ]}
240+ assert isinstance (y_of_t , dict ), f"y_of_t must be a callable or dict, but got { type (y_of_t )} "
170241 y_of_t ['k' ] = self .d_t
171242 self .y_of_t , self .y_of_t_coefs = self .gen_nuisance (** y_of_t )
172243
@@ -199,9 +270,6 @@ def gen_T(self):
199270 def gen_Z (self ):
200271 if self .d_z :
201272 if self .discrete_instrument :
202- # prob_Z = expit(np.random.normal(size=(self.n, self.d_z)))
203- # self.Z = np.random.binomial(1, prob_Z, size=(self.n, 1))
204- # self.Z = np.random.binomial(1, prob_Z)
205273 self .Z = np .random .binomial (1 , 0.5 , size = (self .n , self .d_z ))
206274 return self .Z
207275
@@ -224,7 +292,6 @@ def gen_nuisance(self, k=None, support=1, bounds=[-1, 1], degree=1, intercept=Fa
224292 mask [supports ] = 1
225293 coefs = coefs * mask
226294
227- # orders = np.random.randint(1, degree, k) if degree!=1 else np.ones(shape=(k,))
228295 orders = np .ones (shape = (k ,)) * degree # enforce all to be the degree for now
229296
230297 if intercept :
0 commit comments