3333from apache_beam .transforms .window import TimestampedValue
3434from apache_beam .utils import timestamp
3535from apache_beam .utils .timestamp import MAX_TIMESTAMP
36+ from apache_beam .utils .timestamp import Duration
3637from apache_beam .utils .timestamp import Timestamp
38+ from apache_beam .utils .timestamp import TimestampTypes
3739
3840
3941class ImpulseSeqGenRestrictionProvider (core .RestrictionProvider ):
def initial_restriction(self, element):
  """Return the offset range that indexes every firing in ``[start, end)``.

  Args:
    element: A ``(start, end, interval)`` triple. ``start`` and ``end`` are
      either ``Timestamp`` instances or values that ``Timestamp.of`` can
      convert; ``interval`` is the firing period in seconds.

  Returns:
    ``OffsetRange(0, total_outputs)`` where ``total_outputs`` is the number
    of intervals (rounded up) that fit between ``start`` and ``end``.

  Raises:
    AssertionError: If ``start`` is after ``end``, if ``interval`` is not
      positive, or if ``interval`` converts to a zero-microsecond duration.
  """
  start, end, interval = element

  # Normalize both bounds to Timestamp so the arithmetic below is done on
  # integer microseconds rather than on floating-point seconds.
  if not isinstance(start, Timestamp):
    start = Timestamp.of(start)
  if not isinstance(end, Timestamp):
    end = Timestamp.of(end)

  assert start <= end, 'start must not be later than end'
  assert interval > 0, 'interval must be positive'

  interval_micros = Duration(interval).micros
  # The microsecond count is the divisor below; reject a zero value here so
  # the failure is a validation error rather than a ZeroDivisionError.
  assert interval_micros > 0, 'interval is too small to be represented'

  total_micros = (end - start).micros

  # Exact integer ceiling division. Going through float division would lose
  # precision for very large microsecond counts.
  total_outputs = -(-total_micros // interval_micros)

  return OffsetRange(0, total_outputs)
5958
6059 def create_tracker (self , restriction ):
@@ -230,38 +229,31 @@ def _validate_and_adjust_duration(self):
230229 assert self .data
231230
232231 # The total time we need to impulse all the data.
233- data_duration = (len (self .data ) - 1 ) * self .interval
232+ data_duration = (len (self .data ) - 1 ) * Duration ( self .interval )
234233
235234 is_pre_timestamped = isinstance (self .data [0 ], tuple ) and \
236235 isinstance (self .data [0 ][0 ], timestamp .Timestamp )
237236
238- if isinstance (self .start_ts , Timestamp ):
239- start = self .start_ts .micros / 1000000
240- else :
241- start = self .start_ts
242-
243- if isinstance (self .stop_ts , Timestamp ):
244- if self .stop_ts == MAX_TIMESTAMP :
245- # When the stop timestamp is unbounded (MAX_TIMESTAMP), set it to the
246- # data's actual end time plus an extra fire interval, because the
247- # impulse duration's upper bound is exclusive.
248- end = start + data_duration + self .interval
249- self .stop_ts = Timestamp (micros = end * 1000000 )
250- else :
251- end = self .stop_ts .micros / 1000000
252- else :
253- end = self .stop_ts
237+ start_ts = Timestamp .of (self .start_ts )
238+ stop_ts = Timestamp .of (self .stop_ts )
239+
240+ if stop_ts == MAX_TIMESTAMP :
241+ # When the stop timestamp is unbounded (MAX_TIMESTAMP), set it to the
242+ # data's actual end time plus an extra fire interval, because the
243+ # impulse duration's upper bound is exclusive.
244+ self .stop_ts = start_ts + data_duration + Duration (self .interval )
245+ stop_ts = self .stop_ts
254246
255247 # The total time for the impulse signal which occurs in [start, end).
256- impulse_duration = end - start
257- if round ( data_duration + self .interval , 6 ) < round ( impulse_duration , 6 ) :
248+ impulse_duration = stop_ts - start_ts
249+ if data_duration + Duration ( self .interval ) < impulse_duration :
258250 # We don't have enough data for the impulse.
259251 # If we can fit at least one more data point in the impulse duration,
260252 # then we will be in the repeat mode.
261253 message = 'The number of elements in the provided pre-timestamped ' \
262254 'data sequence is not enough to span the full impulse duration. ' \
263- f'Expected duration: { impulse_duration :.6f } , ' \
264- f'actual data duration: { data_duration :.6f } .'
255+ f'Expected duration: { impulse_duration } , ' \
256+ f'actual data duration: { data_duration } .'
265257
266258 if is_pre_timestamped :
267259 raise ValueError (
@@ -274,8 +266,8 @@ def _validate_and_adjust_duration(self):
274266
275267 def __init__ (
276268 self ,
277- start_timestamp : Timestamp = Timestamp .now (),
278- stop_timestamp : Timestamp = MAX_TIMESTAMP ,
269+ start_timestamp : TimestampTypes = Timestamp .now (),
270+ stop_timestamp : TimestampTypes = MAX_TIMESTAMP ,
279271 fire_interval : float = 360.0 ,
280272 apply_windowing : bool = False ,
281273 data : Optional [Sequence [Any ]] = None ):
@@ -327,11 +319,11 @@ def expand(self, pbegin):
327319 | 'GenSequence' >> beam .ParDo (ImpulseSeqGenDoFn (self .data )))
328320
329321 if not self .data :
330- # This step is only to ensure the current PTransform expansion is
331- # compatible with the previous Beam versions .
332- result = (
333- result
334- | 'MapToTimestamped' >> beam .Map (lambda tt : TimestampedValue ( tt , tt ) ))
322+ # This step is actually an identity transform, because the Timestamped
323+ # values have already been generated in `ImpulseSeqGenDoFn` .
324+ # We keep this step here to prevent the current PeriodicImpulse from
325+ # breaking the compatibility.
326+ result = ( result | 'MapToTimestamped' >> beam .Map (lambda tt : tt ))
335327
336328 if self .apply_windowing :
337329 result = result | 'ApplyWindowing' >> beam .WindowInto (
0 commit comments