@@ -93,6 +93,18 @@ builtin.func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:
9393 return %3 :!torch.vtensor
9494}
9595
96+ // CHECK-LABEL: func @h
97+ // CHECK: torch.aten.conv2d{{.*}} -> !torch.vtensor<[1,16,62,62],f32>
98+ builtin.func @h (%arg0: !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, %arg1: !torch.vtensor <[16 ,8 ,3 ,3 ],f32 >, %arg2: !torch.vtensor <*,f32 >) ->!torch.vtensor {
99+ %int0 = torch.constant.int 0
100+ %int1 = torch.constant.int 1
101+ %stride = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
102+ %padding = torch.prim.ListConstruct %int0 , %int0 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
103+ %dilation = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
104+ %3 = torch.aten.conv2d %arg0 , %arg1 , %arg2 , %stride , %padding , %dilation , %int1 : !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, !torch.vtensor <[16 ,8 ,3 ,3 ],f32 >, !torch.vtensor <*,f32 >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.int ->!torch.vtensor
105+ return %3 :!torch.vtensor
106+ }
107+
96108// -----
97109
98110// CHECK-LABEL: func @f
@@ -110,6 +122,70 @@ builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
110122 return %27 : !torch.vtensor
111123}
112124
125+ // CHECK-LABEL: func @g
126+ builtin.func @g (%arg0: !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >) -> !torch.vtensor {
127+ %int0 = torch.constant.int 0
128+ %int1 = torch.constant.int 1
129+ %int2 = torch.constant.int 2
130+ %int3 = torch.constant.int 3
131+ %bool_false = torch.constant.bool false
132+ %krnl = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
133+ %stride = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
134+ %padding = torch.prim.ListConstruct %int0 , %int0 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
135+ %dilation = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
136+ // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
137+ %27 = torch.aten.max_pool2d %arg0 , %krnl , %stride , %padding , %dilation , %bool_false : !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.bool -> !torch.vtensor
138+ return %27 : !torch.vtensor
139+ }
140+
141+ // CHECK-LABEL: func @h
142+ builtin.func @h (%arg0: !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >) -> !torch.vtensor {
143+ %int0 = torch.constant.int 0
144+ %int1 = torch.constant.int 1
145+ %int2 = torch.constant.int 2
146+ %int3 = torch.constant.int 3
147+ %bool_false = torch.constant.bool false
148+ %krnl = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
149+ %stride = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
150+ %padding = torch.prim.ListConstruct %int0 , %int0 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
151+ %dilation = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
152+ // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,62,62],f32>
153+ %27 = torch.aten.max_pool2d %arg0 , %krnl , %stride , %padding , %dilation , %bool_false : !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.bool -> !torch.vtensor
154+ return %27 : !torch.vtensor
155+ }
156+
157+ // CHECK-LABEL: func @i
158+ builtin.func @i (%arg0: !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >) -> !torch.vtensor {
159+ %int0 = torch.constant.int 0
160+ %int1 = torch.constant.int 1
161+ %int2 = torch.constant.int 2
162+ %int3 = torch.constant.int 3
163+ %bool_false = torch.constant.bool false
164+ %krnl = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
165+ %stride = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
166+ %padding = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
167+ %dilation = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
168+ // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,66,66],f32>
169+ %27 = torch.aten.max_pool2d %arg0 , %krnl , %stride , %padding , %dilation , %bool_false : !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.bool -> !torch.vtensor
170+ return %27 : !torch.vtensor
171+ }
172+
173+ // CHECK-LABEL: func @j
174+ builtin.func @j (%arg0: !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >) -> !torch.vtensor {
175+ %int0 = torch.constant.int 0
176+ %int1 = torch.constant.int 1
177+ %int2 = torch.constant.int 2
178+ %int3 = torch.constant.int 3
179+ %bool_false = torch.constant.bool false
180+ %krnl = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
181+ %stride = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
182+ %padding = torch.prim.ListConstruct %int0 , %int0 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
183+ %dilation = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <!torch.int >
184+ // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
185+ %27 = torch.aten.max_pool2d %arg0 , %krnl , %stride , %padding , %dilation , %bool_false : !torch.vtensor <[1 ,8 ,64 ,64 ],f32 >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.list <!torch.int >, !torch.bool -> !torch.vtensor
186+ return %27 : !torch.vtensor
187+ }
188+
113189// -----
114190
115191// CHECK-LABEL: func @f
0 commit comments