|
92 | 92 | //! source-level module, functions from the same module will be available for
|
93 | 93 | //! inlining, even when they are not marked `#[inline]`.
|
94 | 94 |
|
| 95 | +// Manuel, fixing rebase |
| 96 | +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; |
| 97 | +//use crate::ty::ParamEnv; |
| 98 | +use rustc_middle::ty::ParamEnv; |
| 99 | + |
95 | 100 | use std::cmp;
|
96 | 101 | use std::collections::hash_map::Entry;
|
97 | 102 | use std::fs::{self, File};
|
@@ -1138,6 +1143,50 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au
|
1138 | 1143 | })
|
1139 | 1144 | .collect();
|
1140 | 1145 |
|
| 1146 | + |
| 1147 | + let autodiff_items = items |
| 1148 | + .iter() |
| 1149 | + .filter_map(|item| match *item { |
| 1150 | + MonoItem::Fn(ref instance) => Some((item, instance)), |
| 1151 | + _ => None, |
| 1152 | + }) |
| 1153 | + .filter_map(|(item, instance)| { |
| 1154 | + let target_id = instance.def_id(); |
| 1155 | + let target_attrs = tcx.autodiff_attrs(target_id); |
| 1156 | + if !target_attrs.apply_autodiff() { |
| 1157 | + return None; |
| 1158 | + } |
| 1159 | + |
| 1160 | + let target_symbol = |
| 1161 | + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); |
| 1162 | + let range = usage_map.index.get(&item).unwrap(); |
| 1163 | + |
| 1164 | + let source = usage_map.targets[range.clone()] |
| 1165 | + .into_iter() |
| 1166 | + .filter_map(|item| match *item { |
| 1167 | + MonoItem::Fn(ref instance_s) => { |
| 1168 | + let source_id = instance_s.def_id(); |
| 1169 | + |
| 1170 | + if tcx.autodiff_attrs(source_id).is_active() { |
| 1171 | + return Some(instance_s); |
| 1172 | + } |
| 1173 | + |
| 1174 | + None |
| 1175 | + } |
| 1176 | + _ => None, |
| 1177 | + }) |
| 1178 | + .next(); |
| 1179 | + |
| 1180 | + source.map(|inst| { |
| 1181 | + let (inputs, output) = fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); |
| 1182 | + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); |
| 1183 | + |
| 1184 | + target_attrs.clone().into_item(symb, target_symbol, inputs, output) |
| 1185 | + }) |
| 1186 | + }); |
| 1187 | + |
| 1188 | + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); |
| 1189 | + |
1141 | 1190 | // Output monomorphization stats per def_id
|
1142 | 1191 | if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
|
1143 | 1192 | if let Err(err) =
|
@@ -1198,7 +1247,163 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au
|
1198 | 1247 | }
|
1199 | 1248 | }
|
1200 | 1249 |
|
1201 |
| - (tcx.arena.alloc(mono_items), codegen_units) |
| 1250 | + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) |
| 1251 | +} |
| 1252 | +use rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; |
| 1253 | +use rustc_target::abi::FieldsShape; |
| 1254 | +use std::iter; |
| 1255 | + |
| 1256 | +pub fn typetree_empty() -> TypeTree { |
| 1257 | + TypeTree(vec![]) |
| 1258 | +} |
| 1259 | + |
| 1260 | +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { |
| 1261 | + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { |
| 1262 | + if ty.is_fn_ptr() { |
| 1263 | + unimplemented!("what to do whith fn ptr?"); |
| 1264 | + } |
| 1265 | + |
| 1266 | + let inner_ty = ty.builtin_deref(true).unwrap().ty; |
| 1267 | + let child = typetree_from_ty(inner_ty, tcx, depth + 1); |
| 1268 | + |
| 1269 | + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; |
| 1270 | + //println!("{:depth$} add indirection {:?}", "", tt); |
| 1271 | + |
| 1272 | + return TypeTree(vec![tt]); |
| 1273 | + } |
| 1274 | + |
| 1275 | + if ty.is_scalar() { |
| 1276 | + assert!(!ty.is_any_ptr()); |
| 1277 | + |
| 1278 | + let (kind, size) = if ty.is_integral() { |
| 1279 | + (Kind::Integer, 8) |
| 1280 | + } else { |
| 1281 | + assert!(ty.is_floating_point()); |
| 1282 | + match ty { |
| 1283 | + x if x == tcx.types.f32 => (Kind::Float, 4), |
| 1284 | + x if x == tcx.types.f64 => (Kind::Double, 8), |
| 1285 | + _ => panic!("floatTy scalar that is neither f32 nor f64"), |
| 1286 | + } |
| 1287 | + }; |
| 1288 | + |
| 1289 | + return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); |
| 1290 | + } |
| 1291 | + |
| 1292 | + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; |
| 1293 | + |
| 1294 | + let layout = tcx.layout_of(param_env_and); |
| 1295 | + assert!(layout.is_ok()); |
| 1296 | + |
| 1297 | + let layout = layout.unwrap().layout; |
| 1298 | + let fields = layout.fields(); |
| 1299 | + let max_size = layout.size(); |
| 1300 | + |
| 1301 | + if ty.is_adt() { |
| 1302 | + let adt_def = ty.ty_adt_def().unwrap(); |
| 1303 | + let substs = match ty.kind() { |
| 1304 | + Adt(_, subst_ref) => subst_ref, |
| 1305 | + _ => panic!(""), |
| 1306 | + }; |
| 1307 | + |
| 1308 | + if adt_def.is_struct() { |
| 1309 | + let (offsets, _memory_index) = match fields { |
| 1310 | + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), |
| 1311 | + _ => panic!(""), |
| 1312 | + }; |
| 1313 | + //println!("{:depth$} combine fields", ""); |
| 1314 | + |
| 1315 | + let fields = adt_def.all_fields(); |
| 1316 | + let fields = fields |
| 1317 | + .into_iter() |
| 1318 | + .zip(offsets.into_iter()) |
| 1319 | + .filter_map(|(field, offset)| { |
| 1320 | + let field_ty: Ty<'_> = field.ty(tcx, substs); |
| 1321 | + let field_ty: Ty<'_> = |
| 1322 | + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); |
| 1323 | + |
| 1324 | + if field_ty.is_phantom_data() { |
| 1325 | + return None; |
| 1326 | + } |
| 1327 | + |
| 1328 | + let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; |
| 1329 | + |
| 1330 | + for c in &mut child { |
| 1331 | + if c.offset == -1 { |
| 1332 | + c.offset = offset.bytes() as isize |
| 1333 | + } else { |
| 1334 | + c.offset += offset.bytes() as isize; |
| 1335 | + } |
| 1336 | + } |
| 1337 | + |
| 1338 | + //inner_tt.offset = offset; |
| 1339 | + |
| 1340 | + //println!("{:depth$} -> {:?}", "", child); |
| 1341 | + |
| 1342 | + Some(child) |
| 1343 | + }) |
| 1344 | + .flatten() |
| 1345 | + .collect::<Vec<Type>>(); |
| 1346 | + |
| 1347 | + let ret_tt = TypeTree(fields); |
| 1348 | + //println!("{:depth$} into {:?}", "", ret_tt); |
| 1349 | + return ret_tt; |
| 1350 | + } else { |
| 1351 | + unimplemented!("adt that isn't a struct"); |
| 1352 | + } |
| 1353 | + } |
| 1354 | + |
| 1355 | + if ty.is_array() { |
| 1356 | + let (stride, count) = match fields { |
| 1357 | + FieldsShape::Array { stride: s, count: c } => (s, c), |
| 1358 | + _ => panic!(""), |
| 1359 | + }; |
| 1360 | + let byte_stride = stride.bytes_usize(); |
| 1361 | + let byte_max_size = max_size.bytes_usize(); |
| 1362 | + |
| 1363 | + assert!(byte_stride * *count as usize == byte_max_size); |
| 1364 | + assert!(*count > 0); // return empty TT for empty? |
| 1365 | + let sub_ty = ty.builtin_index().unwrap(); |
| 1366 | + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); |
| 1367 | + |
| 1368 | + // calculate size of subtree |
| 1369 | + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; |
| 1370 | + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; |
| 1371 | + let tt = TypeTree( |
| 1372 | + iter::repeat(subtt) |
| 1373 | + .take(*count as usize) |
| 1374 | + .enumerate() |
| 1375 | + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) |
| 1376 | + .flatten() |
| 1377 | + .collect(), |
| 1378 | + ); |
| 1379 | + |
| 1380 | + //println!("{:depth$} repeated array into {:?}", "", tt); |
| 1381 | + |
| 1382 | + return tt; |
| 1383 | + } |
| 1384 | + |
| 1385 | + if ty.is_slice() { |
| 1386 | + let sub_ty = ty.builtin_index().unwrap(); |
| 1387 | + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); |
| 1388 | + |
| 1389 | + return subtt; |
| 1390 | + } |
| 1391 | + |
| 1392 | + //println!("Warning: create empty typetree for {}", ty); |
| 1393 | + typetree_empty() |
| 1394 | +} |
| 1395 | + |
| 1396 | +pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec<TypeTree>, TypeTree) { |
| 1397 | + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); |
| 1398 | + |
| 1399 | + // TODO: verify. |
| 1400 | + let x: ty::FnSig<'_> = fnc_binder.skip_binder(); |
| 1401 | + |
| 1402 | + let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); |
| 1403 | + |
| 1404 | + let output = typetree_from_ty(x.output(), tcx, 0); |
| 1405 | + |
| 1406 | + (inputs, output) |
1202 | 1407 | }
|
1203 | 1408 |
|
1204 | 1409 | /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
|
|
0 commit comments