@@ -43,7 +43,7 @@ class ToDatetimeOp(base_ops.UnaryOp):
43
43
format : typing .Optional [str ] = None
44
44
unit : typing .Optional [str ] = None
45
45
46
- def output_type (self , * input_types ) :
46
+ def output_type (self , * input_types : dtypes . ExpressionType ) -> dtypes . ExpressionType :
47
47
if input_types [0 ] not in (
48
48
dtypes .FLOAT_DTYPE ,
49
49
dtypes .INT_DTYPE ,
@@ -59,7 +59,7 @@ class ToTimestampOp(base_ops.UnaryOp):
59
59
format : typing .Optional [str ] = None
60
60
unit : typing .Optional [str ] = None
61
61
62
- def output_type (self , * input_types ) :
62
+ def output_type (self , * input_types : dtypes . ExpressionType ) -> dtypes . ExpressionType :
63
63
# Must be numeric or string
64
64
if input_types [0 ] not in (
65
65
dtypes .FLOAT_DTYPE ,
@@ -75,29 +75,35 @@ class StrftimeOp(base_ops.UnaryOp):
75
75
name : typing .ClassVar [str ] = "strftime"
76
76
date_format : str
77
77
78
- def output_type (self , * input_types ) :
78
+ def output_type (self , * input_types : dtypes . ExpressionType ) -> dtypes . ExpressionType :
79
79
return dtypes .STRING_DTYPE
80
80
81
81
82
82
@dataclasses .dataclass (frozen = True )
83
83
class UnixSeconds (base_ops .UnaryOp ):
84
84
name : typing .ClassVar [str ] = "unix_seconds"
85
85
86
- def output_type (self , * input_types ):
86
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
87
+ if input_types [0 ] is not dtypes .TIMESTAMP_DTYPE :
88
+ raise TypeError ("expected timestamp input" )
87
89
return dtypes .INT_DTYPE
88
90
89
91
90
92
@dataclasses .dataclass (frozen = True )
91
93
class UnixMillis (base_ops .UnaryOp ):
92
94
name : typing .ClassVar [str ] = "unix_millis"
93
95
94
- def output_type (self , * input_types ):
96
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
97
+ if input_types [0 ] is not dtypes .TIMESTAMP_DTYPE :
98
+ raise TypeError ("expected timestamp input" )
95
99
return dtypes .INT_DTYPE
96
100
97
101
98
102
@dataclasses .dataclass (frozen = True )
99
103
class UnixMicros (base_ops .UnaryOp ):
100
104
name : typing .ClassVar [str ] = "unix_micros"
101
105
102
- def output_type (self , * input_types ):
106
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
107
+ if input_types [0 ] is not dtypes .TIMESTAMP_DTYPE :
108
+ raise TypeError ("expected timestamp input" )
103
109
return dtypes .INT_DTYPE
0 commit comments