@@ -176,44 +176,3 @@ def policy_gen(Tpre, X, period):
176176 return self ._gen_data_with_policy (n_units , policy_gen , random_seed = random_seed )
177177
178178
179- # Auxiliary function for adding xticks and vertical lines when plotting results
180- # for dynamic dml vs ground truth parameters.
181- def add_vlines (n_periods , n_treatments , hetero_inds ):
182- locs , labels = plt .xticks ([], [])
183- locs += [- .501 + (len (hetero_inds ) + 1 ) / 2 ]
184- labels += ["\n \n $\\ tau_{{{}}}$" .format (0 )]
185- locs += [qx for qx in np .arange (len (hetero_inds ) + 1 )]
186- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
187- for q in np .arange (1 , n_treatments ):
188- plt .axvline (x = q * (len (hetero_inds ) + 1 ) - .5 ,
189- linestyle = '--' , color = 'red' , alpha = .2 )
190- locs += [q * (len (hetero_inds ) + 1 ) - .501 + (len (hetero_inds ) + 1 ) / 2 ]
191- labels += ["\n \n $\\ tau_{{{}}}$" .format (q )]
192- locs += [(q * (len (hetero_inds ) + 1 ) + qx )
193- for qx in np .arange (len (hetero_inds ) + 1 )]
194- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
195- locs += [- .501 + (len (hetero_inds ) + 1 ) * n_treatments / 2 ]
196- labels += ["\n \n \n \n $\\ theta_{{{}}}$" .format (0 )]
197- for t in np .arange (1 , n_periods ):
198- plt .axvline (x = t * (len (hetero_inds ) + 1 ) *
199- n_treatments - .5 , linestyle = '-' , alpha = .6 )
200- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments - .501 +
201- (len (hetero_inds ) + 1 ) * n_treatments / 2 ]
202- labels += ["\n \n \n \n $\\ theta_{{{}}}$" .format (t )]
203- locs += [t * (len (hetero_inds ) + 1 ) *
204- n_treatments - .501 + (len (hetero_inds ) + 1 ) / 2 ]
205- labels += ["\n \n $\\ tau_{{{}}}$" .format (0 )]
206- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments +
207- qx for qx in np .arange (len (hetero_inds ) + 1 )]
208- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
209- for q in np .arange (1 , n_treatments ):
210- plt .axvline (x = t * (len (hetero_inds ) + 1 ) * n_treatments + q * (len (hetero_inds ) + 1 ) - .5 ,
211- linestyle = '--' , color = 'red' , alpha = .2 )
212- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments + q *
213- (len (hetero_inds ) + 1 ) - .501 + (len (hetero_inds ) + 1 ) / 2 ]
214- labels += ["\n \n $\\ tau_{{{}}}$" .format (q )]
215- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments + (q * (len (hetero_inds ) + 1 ) + qx )
216- for qx in np .arange (len (hetero_inds ) + 1 )]
217- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
218- plt .xticks (locs , labels )
219- plt .tight_layout ()
0 commit comments