Skip to content

Commit 4588f5f

Browse files
committed
Merge branch 'main' into doug/init-hugr-llvm-subtree
2 parents b8413cc + 649589c commit 4588f5f

File tree

17 files changed

+1008
-197
lines changed

17 files changed

+1008
-197
lines changed

hugr-core/src/export.rs

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
type_param::{TypeArgVariable, TypeParam},
88
type_row::TypeRowBase,
99
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
10-
TypeBase, TypeEnum,
10+
TypeBase, TypeBound, TypeEnum,
1111
},
1212
Direction, Hugr, HugrView, IncomingPort, Node, Port,
1313
};
@@ -44,8 +44,21 @@ struct Context<'a> {
4444
bump: &'a Bump,
4545
/// Stores the terms that we have already seen to avoid duplicates.
4646
term_map: FxHashMap<model::Term<'a>, model::TermId>,
47+
4748
/// The current scope for local variables.
49+
///
50+
/// This is set to the id of the smallest enclosing node that defines a polymorphic type.
51+
/// We use this when exporting local variables in terms.
4852
local_scope: Option<model::NodeId>,
53+
54+
/// Constraints to be added to the local scope.
55+
///
56+
/// When exporting a node that defines a polymorphic type, we use this field
57+
/// to collect the constraints that need to be added to that polymorphic
58+
/// type. Currently this is used to record `nonlinear` constraints on uses
59+
/// of `TypeParam::Type` with a `TypeBound::Copyable` bound.
60+
local_constraints: Vec<model::TermId>,
61+
4962
/// Mapping from extension operations to their declarations.
5063
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
5164
}
@@ -63,6 +76,7 @@ impl<'a> Context<'a> {
6376
term_map: FxHashMap::default(),
6477
local_scope: None,
6578
decl_operations: FxHashMap::default(),
79+
local_constraints: Vec::new(),
6680
}
6781
}
6882

@@ -173,9 +187,11 @@ impl<'a> Context<'a> {
173187
}
174188

175189
fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
176-
let old_scope = self.local_scope.replace(node);
190+
let prev_local_scope = self.local_scope.replace(node);
191+
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
177192
let result = f(self);
178-
self.local_scope = old_scope;
193+
self.local_scope = prev_local_scope;
194+
self.local_constraints = prev_local_constraints;
179195
result
180196
}
181197

@@ -232,10 +248,11 @@ impl<'a> Context<'a> {
232248

233249
OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
234250
let name = this.get_func_name(node).unwrap();
235-
let (params, signature) = this.export_poly_func_type(&func.signature);
251+
let (params, constraints, signature) = this.export_poly_func_type(&func.signature);
236252
let decl = this.bump.alloc(model::FuncDecl {
237253
name,
238254
params,
255+
constraints,
239256
signature,
240257
});
241258
let extensions = this.export_ext_set(&func.signature.body().extension_reqs);
@@ -247,10 +264,11 @@ impl<'a> Context<'a> {
247264

248265
OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
249266
let name = this.get_func_name(node).unwrap();
250-
let (params, func) = this.export_poly_func_type(&func.signature);
267+
let (params, constraints, func) = this.export_poly_func_type(&func.signature);
251268
let decl = this.bump.alloc(model::FuncDecl {
252269
name,
253270
params,
271+
constraints,
254272
signature: func,
255273
});
256274
model::Operation::DeclareFunc { decl }
@@ -450,10 +468,11 @@ impl<'a> Context<'a> {
450468

451469
let decl = self.with_local_scope(node, |this| {
452470
let name = this.make_qualified_name(opdef.extension(), opdef.name());
453-
let (params, r#type) = this.export_poly_func_type(poly_func_type);
471+
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
454472
let decl = this.bump.alloc(model::OperationDecl {
455473
name,
456474
params,
475+
constraints,
457476
r#type,
458477
});
459478
decl
@@ -671,22 +690,36 @@ impl<'a> Context<'a> {
671690
regions.into_bump_slice()
672691
}
673692

693+
/// Exports a polymorphic function type.
694+
///
695+
/// The returned triple consists of:
696+
/// - The static parameters of the polymorphic function type.
697+
/// - The constraints of the polymorphic function type.
698+
/// - The function type itself.
674699
pub fn export_poly_func_type<RV: MaybeRV>(
675700
&mut self,
676701
t: &PolyFuncTypeBase<RV>,
677-
) -> (&'a [model::Param<'a>], model::TermId) {
702+
) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) {
678703
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
704+
let scope = self
705+
.local_scope
706+
.expect("exporting poly func type outside of local scope");
679707

680708
for (i, param) in t.params().iter().enumerate() {
681709
let name = self.bump.alloc_str(&i.to_string());
682-
let r#type = self.export_type_param(param);
683-
let param = model::Param::Implicit { name, r#type };
710+
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
711+
let param = model::Param {
712+
name,
713+
r#type,
714+
sort: model::ParamSort::Implicit,
715+
};
684716
params.push(param)
685717
}
686718

719+
let constraints = self.bump.alloc_slice_copy(&self.local_constraints);
687720
let body = self.export_func_type(t.body());
688721

689-
(params.into_bump_slice(), body)
722+
(params.into_bump_slice(), constraints, body)
690723
}
691724

692725
pub fn export_type<RV: MaybeRV>(&mut self, t: &TypeBase<RV>) -> model::TermId {
@@ -703,7 +736,6 @@ impl<'a> Context<'a> {
703736
}
704737
TypeEnum::Function(func) => self.export_func_type(func),
705738
TypeEnum::Variable(index, _) => {
706-
// This ignores the type bound for now
707739
let node = self.local_scope.expect("local variable out of scope");
708740
self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _)))
709741
}
@@ -794,20 +826,39 @@ impl<'a> Context<'a> {
794826
self.make_term(model::Term::List { items, tail: None })
795827
}
796828

797-
pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
829+
/// Exports a `TypeParam` to a term.
830+
///
831+
/// The `var` argument is set when the type parameter being exported is the
832+
/// type of a parameter to a polymorphic definition. In that case we can
833+
/// generate a `nonlinear` constraint for the type of runtime types marked as
834+
/// `TypeBound::Copyable`.
835+
pub fn export_type_param(
836+
&mut self,
837+
t: &TypeParam,
838+
var: Option<model::LocalRef<'static>>,
839+
) -> model::TermId {
798840
match t {
799-
// This ignores the type bound for now.
800-
TypeParam::Type { .. } => self.make_term(model::Term::Type),
801-
// This ignores the type bound for now.
841+
TypeParam::Type { b } => {
842+
if let (Some(var), TypeBound::Copyable) = (var, b) {
843+
let term = self.make_term(model::Term::Var(var));
844+
let non_linear = self.make_term(model::Term::NonLinearConstraint { term });
845+
self.local_constraints.push(non_linear);
846+
}
847+
848+
self.make_term(model::Term::Type)
849+
}
850+
// This ignores the bound on the natural for now.
802851
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
803852
TypeParam::String => self.make_term(model::Term::StrType),
804853
TypeParam::List { param } => {
805-
let item_type = self.export_type_param(param);
854+
let item_type = self.export_type_param(param, None);
806855
self.make_term(model::Term::ListType { item_type })
807856
}
808857
TypeParam::Tuple { params } => {
809858
let items = self.bump.alloc_slice_fill_iter(
810-
params.iter().map(|param| self.export_type_param(param)),
859+
params
860+
.iter()
861+
.map(|param| self.export_type_param(param, None)),
811862
);
812863
let types = self.make_term(model::Term::List { items, tail: None });
813864
self.make_term(model::Term::ApplyFull {

hugr-core/src/extension/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ lazy_static! {
101101
NoopDef.add_to_extension(&mut prelude).unwrap();
102102
LiftDef.add_to_extension(&mut prelude).unwrap();
103103
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
104+
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
104105
prelude
105106
};
106107
/// An extension registry containing only the prelude

0 commit comments

Comments
 (0)